-
Notifications
You must be signed in to change notification settings - Fork 4
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Implement more additive models #161
Conversation
Reviewer's Guide by SourceryThis pull request introduces four new additive models (Bent Power Law, Broken Power Law, Smoothly Broken Power Law, and Lorentz) to the Updated class diagram for additive modelsclassDiagram
class NumIntAdditive {
+continuum(egrid: JAXArray, params: NameValMapping): JAXArray
}
class AnaIntAdditive {
+integral(egrid: JAXArray, params: NameValMapping): JAXArray
}
class Band extends NumIntAdditive
class BandEp extends NumIntAdditive
class BentPL extends NumIntAdditive {
-_config
+continuum(egrid: JAXArray, params: NameValMapping): JAXArray
}
class Blackbody extends NumIntAdditive
class BlackbodyRad extends NumIntAdditive
class BrokenPL extends AnaIntAdditive {
-_config
+integral(egrid: JAXArray, params: NameValMapping): JAXArray
}
class Compt extends NumIntAdditive
class CutoffPL extends NumIntAdditive
class Gauss extends NumIntAdditive
class Lorentz extends AnaIntAdditive {
-_config
+integral(egrid: JAXArray, params: NameValMapping): JAXArray
}
class OTTB extends NumIntAdditive
class OTTS extends NumIntAdditive
class PLEnFlux extends NumIntAdditive
class PLPhFlux extends NumIntAdditive
class PowerLaw extends AnaIntAdditive
class SmoothlyBrokenPL extends NumIntAdditive {
-_config
+continuum(egrid: JAXArray, params: NameValMapping): JAXArray
}
NumIntAdditive <|-- Band
NumIntAdditive <|-- BandEp
NumIntAdditive <|-- BentPL
NumIntAdditive <|-- Blackbody
NumIntAdditive <|-- BlackbodyRad
AnaIntAdditive <|-- BrokenPL
NumIntAdditive <|-- Compt
NumIntAdditive <|-- CutoffPL
NumIntAdditive <|-- Gauss
AnaIntAdditive <|-- Lorentz
NumIntAdditive <|-- OTTB
NumIntAdditive <|-- OTTS
NumIntAdditive <|-- PLEnFlux
NumIntAdditive <|-- PLPhFlux
AnaIntAdditive <|-- PowerLaw
NumIntAdditive <|-- SmoothlyBrokenPL
File-Level Changes
Tips and commandsInteracting with Sourcery
Customizing Your ExperienceAccess your dashboard to:
Getting Help
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hey @wcxve - I've reviewed your changes - here's some feedback:
Overall Comments:
- It would be good to include a unit test for each of the new models.
- Consider refactoring common logic into helper functions to reduce code duplication between models.
Here's what I looked at during the review
- 🟡 General issues: 1 issue found
- 🟢 Security: all looks good
- 🟢 Testing: all looks good
- 🟡 Complexity: 1 issue found
- 🟢 Documentation: all looks good
Help me be more useful! Please click 👍 or 👎 on each comment and I'll use the feedback to improve your reviews.
src/elisa/models/add.py
Outdated
idx = jnp.flatnonzero(mask, size=egrid.size - 1)[-1] | ||
pb1 = _powerlaw_integral(jnp.hstack([egrid_[idx], Eb]), alpha1) | ||
pb2 = _powerlaw_integral(jnp.hstack([Eb, egrid_[idx + 1]]), alpha2) | ||
f.at[idx].set(sum(pb1 + pb2)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
suggestion (performance): Consider using jnp.sum rather than the built-in sum.
Using Python’s built-in sum on arrays can be unpredictable in a JAX context. Switching to jnp.sum or a similar JAX-native operation might improve clarity and performance.
f.at[idx].set(sum(pb1 + pb2)) | |
f.at[idx].set(jnp.sum(pb1 + pb2)) |
|
||
@staticmethod | ||
def integral(egrid: JAXArray, params: NameValMapping) -> JAXArray: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
issue (complexity): Consider refactoring the integral calculation in BrokenPL
to use vectorized operations and piecewise integration, avoiding manual index finding and in-place updates.
Consider refactoring the integral so that you avoid manual index‐finding and in-place updates. For example, instead of building a mask, extracting an index with `jnp.flatnonzero`, and then patching that bin with a custom “cross‐bin” calculation, you can compute the definite integrals in a vectorized, piecewise manner. One approach is:
```python
@staticmethod
def integral(egrid: JAXArray, params: NameValMapping) -> JAXArray:
alpha1 = params['alpha1']
Eb = params['Eb']
alpha2 = params['alpha2']
K = params['K']
# Define the antiderivative in terms of the normalized energy E/Eb.
def antideriv(E, alpha):
return _powerlaw_integral(E / Eb, alpha)
# Bin edges
E_low = egrid[:-1]
E_high = egrid[1:]
# Determine which bins are fully below, fully above, or cross the break.
lower_bins = E_high <= Eb
upper_bins = E_low >= Eb
cross_bins = ~(lower_bins | upper_bins)
# Compute integrals for bins entirely in one regime:
int_lower = antideriv(E_high[lower_bins], alpha1) - antideriv(E_low[lower_bins], alpha1)
int_upper = antideriv(E_high[upper_bins], alpha2) - antideriv(E_low[upper_bins], alpha2)
# For bins that cross the break, split the integration at Eb.
int_cross = (antideriv(Eb, alpha1) - antideriv(E_low[cross_bins], alpha1) +
antideriv(E_high[cross_bins], alpha2) - antideriv(Eb, alpha2))
# Allocate and fill the result.
result = jnp.empty(egrid.size - 1)
result = result.at[lower_bins].set(int_lower)
result = result.at[upper_bins].set(int_upper)
result = result.at[cross_bins].set(int_cross)
return K * result
This approach removes the need for manual index extraction and in-place patching while preserving all functionality. Adjust the antiderivative logic if needed based on your implementation of _powerlaw_integral
.
@sourcery-ai review |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hey @wcxve - I've reviewed your changes - here's some feedback:
Overall Comments:
- Consider adding a short description of the parameters to the docstrings of the new models.
- It might be good to have a consistent style for the parameter names (e.g.
El
vsE_l
).
Here's what I looked at during the review
- 🟡 General issues: 1 issue found
- 🟢 Security: all looks good
- 🟡 Testing: 1 issue found
- 🟡 Complexity: 1 issue found
- 🟢 Documentation: all looks good
Help me be more useful! Please click 👍 or 👎 on each comment and I'll use the feedback to improve your reviews.
p1 = _powerlaw_integral(egrid_, alpha1) | ||
p2 = _powerlaw_integral(egrid_, alpha2) | ||
f = jnp.where(mask, p1, p2) | ||
idx = jnp.flatnonzero(mask, size=egrid.size - 1)[-1] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
issue (bug_risk): Potential normalization mismatch in BrokenPL.integral endpoints.
In BrokenPL.integral, the energy grid is normalized (egrid_ = egrid / Eb) so that the break occurs at 1, yet when computing pb1 and pb2 the code uses Eb directly. Verify that the endpoints for _powerlaw_integral are correctly scaled (possibly using 1.0 rather than Eb) to ensure consistency.
tests/test_model.py
Outdated
egrid = np.geomspace(1e-4, 1e9, 1301) | ||
assert not np.any(np.isnan(model(**kwargs).compile().eval(egrid))) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
suggestion (testing): Consider testing for infinite values.
Besides NaN values, it's also important to check for infinite values, as they can also indicate issues with the models. Suggest adding np.isinf
checks as well.
] | ||
|
||
|
||
def _powerlaw_integral(egrid: JAXArray, alpha: JAXArray) -> JAXArray: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
issue (complexity): Consider refactoring the integration logic in _powerlaw_integral
and BrokenPL.integral
into smaller, purpose-specific helper functions to improve readability and maintainability.
Consider extracting parts of the integration logic into small, purpose‐specific helper functions to simplify the conditional handling and array slicing. For example, in _powerlaw_integral
you can separate the power‐law and logarithmic parts:
def _integral_non_unity(egrid: JAXArray, alpha: JAXArray) -> JAXArray:
one_minus_alpha = 1.0 - alpha
f = jnp.power(egrid, one_minus_alpha) / one_minus_alpha
return f[1:] - f[:-1]
def _integral_unity(egrid: JAXArray) -> JAXArray:
f = jnp.log(egrid)
return f[1:] - f[:-1]
def _powerlaw_integral(egrid: JAXArray, alpha: JAXArray) -> JAXArray:
# Use jnp.not_equal to decide which integration rule to apply
condition = jnp.not_equal(alpha, 1.0)
f1 = _integral_non_unity(egrid, alpha)
f2 = _integral_unity(egrid)
return jnp.where(condition[:-1], f1, f2)
Similarly, in the integral
method for BrokenPL
, isolate the break-region adjustment:
def _adjust_break_integral(egrid_: JAXArray, Eb: float, alpha1: float, alpha2: float, idx: int) -> float:
pb1 = _powerlaw_integral(jnp.hstack([egrid_[idx], Eb]), alpha1)
pb2 = _powerlaw_integral(jnp.hstack([Eb, egrid_[idx + 1]]), alpha2)
return pb1[0] + pb2[0]
# In the BrokenPL.integral method:
mask = egrid[:-1] <= Eb
# ... existing setup ...
idx = jnp.flatnonzero(mask, size=egrid.size - 1)[-1]
f = f.at[idx].set(_adjust_break_integral(egrid_, Eb, alpha1, alpha2, idx))
These small refactorings keep functionality intact while isolating complex operations, making the code easier to read and maintain without reverting any changes.
Summary by Sourcery
Implements several new additive models, including Bent Power Law, Broken Power Law, Smoothly Broken Power Law, and Lorentz models. Also adds a test to evaluate all available models.
New Features:
Tests: