-
Notifications
You must be signed in to change notification settings - Fork 4
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Implement more additive models #161
Changes from all commits
18bc8f8
7b2bd9a
5feeb13
35649d5
bc5651d
21dfbca
01f9d49
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -20,19 +20,36 @@ | |
__all__ = [ | ||
'Band', | ||
'BandEp', | ||
'BentPL', | ||
'Blackbody', | ||
'BlackbodyRad', | ||
'BrokenPL', | ||
'Compt', | ||
'CutoffPL', | ||
'Gauss', | ||
'Lorentz', | ||
'OTTB', | ||
'OTTS', | ||
'PLEnFlux', | ||
'PLPhFlux', | ||
'PowerLaw', | ||
'SmoothlyBrokenPL', | ||
] | ||
|
||
|
||
def _powerlaw_integral(egrid: JAXArray, alpha: JAXArray) -> JAXArray: | ||
cond = jnp.full(len(egrid), jnp.not_equal(alpha, 1.0)) | ||
|
||
one_minus_alpha = jnp.where(cond, 1.0 - alpha, 1.0) | ||
f1 = jnp.power(egrid, one_minus_alpha) / one_minus_alpha | ||
f1 = f1[1:] - f1[:-1] | ||
|
||
f2 = jnp.log(egrid) | ||
f2 = f2[1:] - f2[:-1] | ||
|
||
return jnp.where(cond[:-1], f1, f2) | ||
|
||
|
||
class Band(NumIntAdditive): | ||
r"""Gamma-ray burst continuum developed by Band et al. (1993) [1]_. | ||
|
||
|
@@ -183,6 +200,42 @@ def continuum(egrid: JAXArray, params: NameValMapping) -> JAXArray: | |
return K * jnp.exp(log) | ||
|
||
|
||
class BentPL(NumIntAdditive): | ||
r"""Bent power law. | ||
|
||
.. math:: | ||
N(E) = 2K \left[ | ||
1 + \left(\frac{E}{E_\mathrm{b}}\right)^\alpha | ||
\right]^{-1} | ||
|
||
Parameters | ||
---------- | ||
alpha : Parameter, optional | ||
The power law photon index :math:`\alpha`, dimensionless. | ||
Eb : Parameter, optional | ||
The break energy :math:`E_\mathrm{b}`, in units of keV. | ||
K : Parameter, optional | ||
The amplitude :math:`K`, in units of ph cm⁻² s⁻¹ keV⁻¹. | ||
latex : str, optional | ||
:math:`\LaTeX` format of the component. Defaults to class name. | ||
method : {'trapz', 'simpson'}, optional | ||
Numerical integration method. Defaults to 'trapz'. | ||
""" | ||
|
||
_config = ( | ||
ParamConfig('alpha', 'alpha', '', 1.5, -3.0, 10.0), | ||
ParamConfig('Eb', r'E_\mathrm{b}', 'keV', 5.0, 1e-2, 1e6), | ||
ParamConfig('K', 'K', 'ph cm^-2 s^-1 keV^-1', 1.0, 1e-10, 1e10), | ||
) | ||
|
||
@staticmethod | ||
def continuum(egrid: JAXArray, params: NameValMapping) -> JAXArray: | ||
alpha = params['alpha'] | ||
Eb = params['Eb'] | ||
K = params['K'] | ||
return K * 2.0 / (1.0 + jnp.power(egrid / Eb, alpha)) | ||
|
||
|
||
class Blackbody(NumIntAdditive): | ||
r"""Blackbody function. | ||
|
||
|
@@ -287,14 +340,60 @@ def continuum(egrid: JAXArray, params: NameValMapping) -> JAXArray: | |
|
||
|
||
class BrokenPL(AnaIntAdditive): | ||
pass | ||
r"""Broken power law. | ||
|
||
.. math:: | ||
N(E) = K | ||
\begin{cases} | ||
\left(\frac{E}{E_\mathrm{b}}\right)^{-\alpha_1}, | ||
&\text{if } E \le E_\mathrm{b}, | ||
\\\\ | ||
\left(\frac{E}{E_\mathrm{b}}\right)^{-\alpha_2}, | ||
&\text{otherwise}. | ||
\end{cases} | ||
|
||
class DoubleBrokenPL(AnaIntAdditive): | ||
pass | ||
Parameters | ||
---------- | ||
alpha1 : Parameter, optional | ||
The low-energy power law photon index :math:`\alpha_1`, dimensionless. | ||
alpha2 : Parameter, optional | ||
The high-energy power law photon index :math:`\alpha_2`, dimensionless. | ||
Eb : Parameter, optional | ||
The break energy :math:`E_\mathrm{b}`, in units of keV. | ||
K : Parameter, optional | ||
The amplitude :math:`K`, in units of ph cm⁻² s⁻¹ keV⁻¹. | ||
latex : str, optional | ||
:math:`\LaTeX` format of the component. Defaults to class name. | ||
method : {'trapz', 'simpson'}, optional | ||
Numerical integration method. Defaults to 'trapz'. | ||
""" | ||
|
||
_config = ( | ||
ParamConfig('alpha1', 'alpha_1', '', 1.0, -3.0, 10.0), | ||
ParamConfig('Eb', r'E_\mathrm{b}', 'keV', 5.0, 1e-2, 1e6), | ||
ParamConfig('alpha2', 'alpha_2', '', 2.0, -3.0, 10.0), | ||
ParamConfig('K', 'K', 'ph cm^-2 s^-1 keV^-1', 1.0, 1e-10, 1e10), | ||
) | ||
|
||
class SmoothlyBrokenPL(NumIntAdditive): | ||
@staticmethod | ||
def integral(egrid: JAXArray, params: NameValMapping) -> JAXArray: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. issue (complexity): Consider refactoring the integral calculation in Consider refactoring the integral so that you avoid manual index‐finding and in-place updates. For example, instead of building a mask, extracting an index with `jnp.flatnonzero`, and then patching that bin with a custom “cross‐bin” calculation, you can compute the definite integrals in a vectorized, piecewise manner. One approach is:
```python
@staticmethod
def integral(egrid: JAXArray, params: NameValMapping) -> JAXArray:
alpha1 = params['alpha1']
Eb = params['Eb']
alpha2 = params['alpha2']
K = params['K']
# Define the antiderivative in terms of the normalized energy E/Eb.
def antideriv(E, alpha):
return _powerlaw_integral(E / Eb, alpha)
# Bin edges
E_low = egrid[:-1]
E_high = egrid[1:]
# Determine which bins are fully below, fully above, or cross the break.
lower_bins = E_high <= Eb
upper_bins = E_low >= Eb
cross_bins = ~(lower_bins | upper_bins)
# Compute integrals for bins entirely in one regime:
int_lower = antideriv(E_high[lower_bins], alpha1) - antideriv(E_low[lower_bins], alpha1)
int_upper = antideriv(E_high[upper_bins], alpha2) - antideriv(E_low[upper_bins], alpha2)
# For bins that cross the break, split the integration at Eb.
int_cross = (antideriv(Eb, alpha1) - antideriv(E_low[cross_bins], alpha1) +
antideriv(E_high[cross_bins], alpha2) - antideriv(Eb, alpha2))
# Allocate and fill the result.
result = jnp.empty(egrid.size - 1)
result = result.at[lower_bins].set(int_lower)
result = result.at[upper_bins].set(int_upper)
result = result.at[cross_bins].set(int_cross)
return K * result This approach removes the need for manual index extraction and in-place patching while preserving all functionality. Adjust the antiderivative logic if needed based on your implementation of
|
||
alpha1 = params['alpha1'] | ||
Eb = params['Eb'] | ||
alpha2 = params['alpha2'] | ||
K = params['K'] | ||
mask = egrid[:-1] <= Eb | ||
egrid_ = egrid / Eb | ||
p1 = _powerlaw_integral(egrid_, alpha1) | ||
p2 = _powerlaw_integral(egrid_, alpha2) | ||
f = jnp.where(mask, p1, p2) | ||
idx = jnp.flatnonzero(mask, size=egrid.size - 1)[-1] | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. issue (bug_risk): Potential normalization mismatch in BrokenPL.integral endpoints. In BrokenPL.integral, the energy grid is normalized (egrid_ = egrid / Eb) so that the break occurs at 1, yet when computing pb1 and pb2 the code uses Eb directly. Verify that the endpoints for _powerlaw_integral are correctly scaled (possibly using 1.0 rather than Eb) to ensure consistency. |
||
pb1 = _powerlaw_integral(jnp.hstack([egrid_[idx], Eb]), alpha1) | ||
pb2 = _powerlaw_integral(jnp.hstack([Eb, egrid_[idx + 1]]), alpha2) | ||
f.at[idx].set(pb1[0] + pb2[0]) | ||
return K * f | ||
|
||
|
||
class DoubleBrokenPL(AnaIntAdditive): | ||
pass | ||
|
||
|
||
|
@@ -421,8 +520,45 @@ class LogParabola(NumIntAdditive): | |
pass | ||
|
||
|
||
class Lorentz(NumIntAdditive): | ||
pass | ||
class Lorentz(AnaIntAdditive): | ||
r"""Lorentzian line profile. | ||
|
||
.. math:: | ||
N(E) = \frac{K/\mathcal{F}}{(E - E_\mathrm{l})^2 + (\sigma/2)^2}, | ||
|
||
where | ||
|
||
.. math:: | ||
\mathcal{F}=\int_0^{+\infty} | ||
\frac{1}{(E - E_\mathrm{l})^2 + (\sigma/2)^2} \ \mathrm{d}E. | ||
|
||
Parameters | ||
---------- | ||
El : Parameter, optional | ||
The line energy :math:`E_\mathrm{l}`, in units of keV. | ||
sigma : Parameter, optional | ||
The FWHM line width :math:`\sigma`, in units of keV. | ||
K : Parameter, optional | ||
The total photon flux :math:`K` of the line, in units of ph cm⁻² s⁻¹. | ||
latex : str, optional | ||
:math:`\LaTeX` format of the component. Defaults to class name. | ||
""" | ||
|
||
_config = ( | ||
ParamConfig('El', r'E_\mathrm{l}', 'keV', 6.5, 1e-6, 1e6), | ||
ParamConfig('sigma', r'\sigma', 'keV', 0.1, 1e-3, 20), | ||
ParamConfig('K', 'K', 'ph cm^-2 s^-1', 1.0, 1e-10, 1e10), | ||
) | ||
|
||
@staticmethod | ||
def integral(egrid: JAXArray, params: NameValMapping) -> JAXArray: | ||
El = params['El'] | ||
sigma = params['sigma'] | ||
K = params['K'] | ||
x = 2.0 * (egrid - El) / sigma | ||
integral = jnp.arctan(x) | ||
norm = 2.0 * K / (jnp.pi + 2 * jnp.arctan(2 * El / sigma)) | ||
return norm * (integral[1:] - integral[:-1]) | ||
|
||
|
||
class OTTB(NumIntAdditive): | ||
|
@@ -493,19 +629,6 @@ def continuum(egrid: JAXArray, params: NameValMapping) -> JAXArray: | |
return K * jnp.exp(-jnp.power(egrid / Ec, 1.0 / 3.0)) | ||
|
||
|
||
def _powerlaw_integral(egrid: JAXArray, alpha: JAXArray) -> JAXArray: | ||
cond = jnp.full(len(egrid), jnp.not_equal(alpha, 1.0)) | ||
|
||
one_minus_alpha = jnp.where(cond, 1.0 - alpha, 1.0) | ||
f1 = jnp.power(egrid, one_minus_alpha) / one_minus_alpha | ||
f1 = f1[1:] - f1[:-1] | ||
|
||
f2 = jnp.log(egrid) | ||
f2 = f2[1:] - f2[:-1] | ||
|
||
return jnp.where(cond[:-1], f1, f2) | ||
|
||
|
||
class PowerLaw(AnaIntAdditive): | ||
r"""Power law function. | ||
|
||
|
@@ -677,3 +800,62 @@ class PLEnFlux(PLFluxNorm): | |
log=True, | ||
), | ||
) | ||
|
||
|
||
class SmoothlyBrokenPL(NumIntAdditive): | ||
r"""Smoothly broken power law. | ||
|
||
.. math:: | ||
N(E) = K | ||
\left( \frac{E}{E_0} \right) ^ {-\alpha_1} | ||
\left\{ | ||
2\left[ | ||
1 + \left(\frac{E}{E_\mathrm{b}} | ||
\right)^{\left( \alpha_2 - \alpha_1 \right) / \rho} | ||
\right] | ||
\right\}^{-\rho}, | ||
|
||
where :math:`E_0` is the pivot energy fixed at 1 keV. | ||
|
||
Parameters | ||
---------- | ||
alpha1 : Parameter, optional | ||
The low-energy power law index :math:`\alpha_1`, dimensionless. | ||
Eb : Parameter, optional | ||
The break energy :math:`E_\mathrm{b}`, in units of keV. | ||
alpha2 : Parameter, optional | ||
The high-energy power law index :math:`\alpha_2`, dimensionless. | ||
rho : Parameter, optional | ||
The smoothness parameter :math:`\rho`, dimensionless. | ||
K : Parameter, optional | ||
The amplitude :math:`K`, in units of ph cm⁻² s⁻¹ keV⁻¹. | ||
latex : str, optional | ||
:math:`\LaTeX` format of the component. Defaults to class name. | ||
method : {'trapz', 'simpson'}, optional | ||
Numerical integration method. Defaults to 'trapz'. | ||
""" | ||
|
||
_config = ( | ||
ParamConfig('alpha1', 'alpha_1', '', 1.0, -3.0, 10.0), | ||
ParamConfig('Eb', r'E_\mathrm{b}', 'keV', 5.0, 1e-2, 1e6), | ||
ParamConfig('alpha2', 'alpha_2', '', 2.0, -3.0, 10.0), | ||
ParamConfig('rho', r'\rho', '', 1.0, 1e-2, 1e2, fixed=True), | ||
ParamConfig('K', 'K', 'ph cm^-2 s^-1 keV^-1', 1.0, 1e-10, 1e10), | ||
) | ||
|
||
@staticmethod | ||
def continuum(egrid: JAXArray, params: NameValMapping) -> JAXArray: | ||
alpha1 = params['alpha1'] | ||
Eb = params['Eb'] | ||
alpha2 = params['alpha2'] | ||
rho = params['rho'] | ||
K = params['K'] | ||
|
||
e = egrid / Eb | ||
x = (alpha2 - alpha1) / rho * jnp.log(e) | ||
|
||
threshold = 30 | ||
alpha = jnp.where(x > threshold, alpha2, alpha1) | ||
r = jnp.where(jnp.abs(x) > threshold, 0.5, 0.5 * (1.0 + jnp.exp(x))) | ||
|
||
return K * jnp.power(e, -alpha) * jnp.power(r, -rho) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
issue (complexity): Consider refactoring the integration logic in
_powerlaw_integral
andBrokenPL.integral
into smaller, purpose-specific helper functions to improve readability and maintainability.Consider extracting parts of the integration logic into small, purpose‐specific helper functions to simplify the conditional handling and array slicing. For example, in
_powerlaw_integral
you can separate the power‐law and logarithmic parts:Similarly, in the
integral
method forBrokenPL
, isolate the break-region adjustment:These small refactorings keep functionality intact while isolating complex operations, making the code easier to read and maintain without reverting any changes.