Skip to content

Commit

Permalink
Merge pull request #154 from jacanchaplais/patch/numba-pythonic
Browse files Browse the repository at this point in the history
Idiomatic improvements to numba loops
  • Loading branch information
jacanchaplais authored Sep 11, 2023
2 parents 76157bf + 1b8b2a2 commit 0c828dc
Showing 1 changed file with 14 additions and 18 deletions.
32 changes: 14 additions & 18 deletions graphicle/calculate.py
Original file line number Diff line number Diff line change
Expand Up @@ -418,34 +418,35 @@ def flow_trace(
def _rapidity(
energy: base.DoubleVector, z: base.DoubleVector, zero_tol: float
) -> base.DoubleVector:
"""Numpy ufunc to calculate the rapidity of a set of particles.
"""Calculate the rapidity of a set of particles using energy
and longitudinal components of their four-momenta.
Parameters
----------
energy, z : array_like
energy, z : ndarray[float64]
Components of the particles' four-momenta.
zero_tol : float
Absolute tolerance for energy values to be considered close to
zero.
Returns
-------
ndarray or float
ndarray[float64]
Rapidity of the particles.
"""
rap = np.empty_like(energy)
for i in range(len(rap)):
z_ = abs(z[i])
diff = energy[i] - z_
if abs(diff) < zero_tol:
rap_ = math.inf
for idx, (e_val, z_val) in enumerate(zip(energy, z)):
z_val_abs = abs(z_val)
diff = e_val - z_val_abs
if abs(diff) > zero_tol:
rap_val = 0.5 * math.log((e_val + z_val_abs) / diff)
else:
rap_ = 0.5 * math.log((energy[i] + z_) / diff)
rap[i] = math.copysign(rap_, z[i])
rap_val = math.inf
rap[idx] = math.copysign(rap_val, z_val)
return rap


@nb.vectorize([nb.float64(nb.float64, nb.float64)])
@nb.vectorize("float64(float64, float64)")
def _root_diff_two_squares(
x1: base.DoubleUfunc, x2: base.DoubleUfunc
) -> base.DoubleUfunc:
Expand Down Expand Up @@ -480,9 +481,7 @@ def _root_diff_two_squares(


@nb.njit(
nb.float64[:, :](
nb.float64[:], nb.float64[:], nb.complex128[:], nb.complex128[:]
),
"float64[:, :](float64[:], float64[:], complex128[:], complex128[:])",
parallel=True,
)
def _delta_R(
Expand Down Expand Up @@ -522,10 +521,7 @@ def _delta_R(
return result


@nb.njit(
nb.float64[:, :](nb.float64[:], nb.complex128[:]),
parallel=True,
)
@nb.njit("float64[:, :](float64[:], complex128[:])", parallel=True)
def _delta_R_symmetric(
rapidity: base.DoubleVector, xy_pol: base.ComplexVector
) -> base.DoubleVector:
Expand Down

0 comments on commit 0c828dc

Please sign in to comment.