Skip to content

Commit

Permalink
Merge pull request #77 from jacanchaplais/feature/prune-76
Browse files Browse the repository at this point in the history
Prune tagged jet clusters
  • Loading branch information
jacanchaplais authored Feb 23, 2023
2 parents 69c8eed + 53faa88 commit 9f7c40e
Show file tree
Hide file tree
Showing 3 changed files with 97 additions and 11 deletions.
11 changes: 7 additions & 4 deletions graphicle/calculate.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
import numpy.lib.recfunctions as rfn
import pyjet
from pyjet import ClusterSequence, PseudoJet
from typicle import Types

import graphicle as gcl

Expand All @@ -32,8 +31,6 @@
"cluster_pmu",
]

_types = Types()


def azimuth_centre(pmu: gcl.MomentumArray, pt_weight: bool = True) -> float:
"""Calculates the central point in azimuth for a set of particles.
Expand All @@ -60,6 +57,12 @@ def azimuth_centre(pmu: gcl.MomentumArray, pt_weight: bool = True) -> float:
return float(np.angle(pol.sum()))


def pseudorapidity_centre(pmu: gcl.MomentumArray) -> float:
pt_norm = pmu.pt / pmu.pt.sum()
eta_wt_mid = (pmu.eta * pt_norm).sum()
return eta_wt_mid


def combined_mass(
pmu: gcl.MomentumArray | base.VoidVector,
weight: base.DoubleVector | None = None,
Expand Down Expand Up @@ -142,7 +145,7 @@ def _trace_vector(
) -> base.AnyVector:
len_basis = len(basis)
feat_fmt = rfn.structured_to_unstructured if is_structured else lambda x: x
color = np.zeros((len_basis, feat_dim), dtype=_types.double)
color = np.zeros((len_basis, feat_dim), dtype="<f8")
if vertex in basis:
color[basis.index(vertex)] = 1.0
if exclusive is True:
Expand Down
12 changes: 7 additions & 5 deletions graphicle/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -370,9 +370,7 @@ def converter(values: npt.ArrayLike) -> base.AnyVector:

def _truthy(data: ty.Union[base.ArrayBase, base.AdjacencyBase]) -> bool:
"""Defines the truthy value of the graphicle data structures."""
if len(data) == 0:
return False
return True
return not (len(data) == 0)


##################################
Expand Down Expand Up @@ -758,8 +756,12 @@ def __eq__(self, other: base.MaskLike) -> "MaskArray":
def __ne__(self, other: base.MaskLike) -> "MaskArray":
return _mask_neq(self, other)

def copy(self):
return deepcopy(self)
def copy(self) -> "MaskGroup":
mask_copies = map(op.methodcaller("copy"), self._mask_arrays.values())
return self.__class__(
cl.OrderedDict(zip(self._mask_arrays.keys(), mask_copies)),
agg_op=self._agg_op, # type: ignore
)

@property
def names(self) -> ty.List[str]:
Expand Down
85 changes: 83 additions & 2 deletions graphicle/select.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
"hadron_vertices",
"fastjet_clusters",
"leaf_masks",
"centroid_prune",
]


Expand Down Expand Up @@ -300,7 +301,28 @@ def _partition_vertex(
Parameters
----------
mask : gcl.MaskArray
mask : MaskArray
Hard parton descendants.
pcls_in : MaskArray or ndarray[bool_]
The particles entering the hadronisation vertex.
vtx_desc : MaskArray
Particles descending from the hadronisation vertex.
final : MaskArray
Final state particles.
pmu : MomentumArray
Four momenta.
dist_strat : callable
Callable which takes two ``MomentumArray`` instances, and
returns a double array with number of rows and columns equal to
the lengths of the input momenta, respectively. Output should
represent pairwise distance between particles incident on the
hadronisation vertex, and the final state descendants.
Returns
-------
filtered_mask : MaskArray
Input ``MaskArray``, filtered to remove background incident on
the same hadronisation vertex.
"""
mask = mask.copy()
parton_pmu = pmu[pcls_in]
Expand Down Expand Up @@ -358,7 +380,12 @@ def partition_descendants(
if vtx_id not in graph.edges["out"][mask.data]:
continue
mask.data = _partition_vertex(
mask, pcls_in, vtx_desc, graph.final, graph.pmu, dist_strat
mask,
pcls_in,
vtx_desc,
graph.final,
graph.pmu,
dist_strat,
).data
return hier

Expand Down Expand Up @@ -714,3 +741,57 @@ def any_overlap(masks: gcl.MaskGroup) -> bool:
pair_checks = map(np.bitwise_and, *zip(*combos))
overlaps: bool = np.bitwise_or.reduce(tuple(pair_checks), axis=None)
return overlaps


def centroid_prune(
pmu: gcl.MomentumArray,
radius: float,
mask: ty.Optional[gcl.MaskArray] = None,
centre: ty.Optional[ty.Tuple[float, float]] = None,
) -> gcl.MaskArray:
"""For a given ``MomentumArray``, calculate the distance every
particle is from a centroid location, and return a ``MaskArray`` for
all of the particles which are within a given ``radius``.
If ``centre`` is not provided, the transverse momentum weighted
centroid will be used.
:group: select
.. versionadded:: 0.2.4
Parameters
----------
pmu : MomentumArray
Four-momenta for a set of particles.
radius : float
Euclidean distance in the azimuth-pseudorapidity plane from the
centroid, beyond which particles will be filtered out.
mask : MaskArray, optional
If provided, will apply the mask to the passed ``pmu``, and
output ``MaskArray`` will have the same length.
centre : tuple[float, float]
Pseudorapidity and azimuth coordinates for a user-defined
centroid.
Returns
-------
prune_mask : MaskArray
Mask which retains only the particles within ``radius`` of the
centroid.
"""
if mask is not None:
pmu = pmu[mask]
event_mask = np.zeros_like(mask, "<?")
if centre is None:
eta_mid = (pmu.eta * pmu.pt).sum() / pmu.pt.sum()
phi_sum_ = (pmu._xy_pol * pmu.pt).sum()
phi_mid_ = phi_sum_ / np.abs(phi_sum_)
else:
eta_mid, phi_mid = centre
phi_mid_ = np.exp(complex(0, phi_mid))
dist = np.hypot(pmu.eta - eta_mid, np.angle(pmu._xy_pol * phi_mid_.conj()))
is_within = dist < radius
if mask is None:
return gcl.MaskArray(is_within)
event_mask[mask] = is_within
return gcl.MaskArray(event_mask)

0 comments on commit 9f7c40e

Please sign in to comment.