Skip to content

Commit

Permalink
MaskGroup classmethod from numpy + enable numerical subscript +
Browse files Browse the repository at this point in the history
docstrings
  • Loading branch information
jacanchaplais committed Apr 21, 2022
1 parent 120768f commit 8c4a2f2
Showing 1 changed file with 180 additions and 5 deletions.
185 changes: 180 additions & 5 deletions graphicle/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,13 +138,25 @@ class MaskGroup(MaskBase):
repr=False, factory=dict, converter=_mask_dict_convert
)

@classmethod
def from_numpy_structured(cls, arr: np.ndarray):
    """Creates a MaskGroup from a numpy structured array.

    Parameters
    ----------
    arr : ndarray
        Structured array. Each named field is extracted and stored
        as one mask, keyed by the field's name.

    Returns
    -------
    mask_group
        Instance of this class, built from a dictionary mapping
        each field name to the corresponding field of ``arr``.

    Raises
    ------
    TypeError
        If ``arr`` has no named fields, ie. is not a structured
        array.
    """
    field_names = arr.dtype.names
    # dtype.names is None for ordinary (unstructured) arrays; fail with
    # an explicit message rather than "'NoneType' is not iterable".
    if field_names is None:
        raise TypeError(
            "arr must be a structured array with named fields."
        )
    return cls({name: arr[name] for name in field_names})

def __repr__(self):
    """Summarises the group by listing the names of its masks."""
    joined_names = ", ".join(self.names)
    return "MaskGroup(mask_arrays=[" + joined_names + "])"

def __getitem__(self, key):
    """Subscripts the group either by mask name, or element-wise.

    Parameters
    ----------
    key : str or any other subscript
        If a string, it is treated as the name of a mask in the
        group. Any other key is applied as a subscript to every
        mask in the group.

    Returns
    -------
    mask or mask_group
        For a string key, the mask stored under that name.
        Otherwise, a new instance of this class holding the result
        of subscripting each mask with ``key``, under the same names.

    Raises
    ------
    KeyError
        If ``key`` is a string which does not name a stored mask.
    """
    if isinstance(key, str):
        return self._mask_arrays[key]
    # Non-string keys are forwarded to every mask, so the names are
    # preserved and only the contents of each mask are subscripted.
    return self.__class__(
        {name: mask[key] for name, mask in self._mask_arrays.items()}
    )

def __setitem__(self, key, mask):
Expand Down Expand Up @@ -229,10 +241,11 @@ class PdgArray(ArrayBase):

data: np.ndarray = array_field("int")
__lookup_table: __PdgRecords = field(init=False, repr=False)
__mega_to_giga: float = 1.0e-3
__mega_to_giga: float = field(init=False, repr=False)

def __attrs_post_init__(self):
    """Fills in the private fields which are excluded from the
    generated initialiser.
    """
    # Lookup table instance backing this array's PDG record queries.
    self.__lookup_table = self.__PdgRecords()
    # Conversion factor stored for later use: 1.0e-3.
    self.__mega_to_giga = 1.0e-3

def __len__(self):
    """Number of elements held in ``data``."""
    n_elements = len(self.data)
    return n_elements
Expand Down Expand Up @@ -457,9 +470,19 @@ def __array__(self):
####################
@define
class HelicityArray(ArrayBase):
"""Data structure containing helicity / polarisation values for
particle set.
Attributes
----------
data : ndarray
Helicity values.
"""

data: np.ndarray = array_field("helicity")

def copy(self):
    """Returns a new HelicityArray instance with same data."""
    return deepcopy(self)

def __getitem__(self, key):
Expand All @@ -479,9 +502,23 @@ def __array__(self):
####################################
@define
class StatusArray(ArrayBase):
"""Data structure containing status values for particle set.
Attributes
----------
data : ndarray
Status codes.
Notes
-----
These codes are specific to the Monte-Carlo event generators which
produced the data.
"""

data: np.ndarray = array_field("h_int")

def copy(self):
    """Produces an independent StatusArray, deep copied from this
    instance.
    """
    duplicate = deepcopy(self)
    return duplicate

def __getitem__(self, key):
Expand Down Expand Up @@ -545,12 +582,22 @@ def hard_mask(self) -> MaskGroup:
#########################################
@define
class ParticleSet(ParticleBase):
"""Combines rich particle description.
"""Composite of data structures containing particle set description.
Attributes
----------
data : ndarray
Structured array containing color / anti-color pairs.
pdg : PdgArray
PDG codes.
pmu : MomentumArray
Four momenta.
color : ColorArray
Color / anti-color pairs.
helicity : HelicityArray
Helicity values.
status : StatusArray
Status codes from Monte-Carlo event generator.
final : MaskArray
Boolean array indicating final state in particle set.
"""

pdg: PdgArray = PdgArray()
Expand Down Expand Up @@ -597,6 +644,32 @@ def from_numpy(
status: Optional[np.ndarray] = None,
final: Optional[np.ndarray] = None,
):
"""Creates a ParticleSet instance directly from numpy arrays.
Parameters
----------
pdg : ndarray, optional
PDG codes.
pmu : ndarray, optional
Four momenta, formatted in columns of (x, y, z, e), or as
a structured array with those fields.
color : ndarray, optional
Color / anti-color pairs, formatted in columns of
(col, acol), or as a structured array with those fields.
helicity : ndarray, optional
Helicity values.
status : ndarray, optional
Status codes from Monte-Carlo event generator.
final : ndarray, optional
Boolean array indicating which particles are final state.
Returns
-------
particle_set : ParticleSet
A composite object, wrapping the data provided in Graphicle
objects, and providing a unified interface to them.
"""

def optional(data_class, data: Optional[np.ndarray]):
    """Wraps ``data`` in ``data_class``, falling back to a default
    constructed ``data_class`` when no data is given (``None``).
    """
    if data is None:
        return data_class()
    return data_class(data)

Expand All @@ -622,6 +695,21 @@ class _AdjDict(TypedDict):

@define
class AdjacencyList(AdjacencyBase):
"""Describes relations between particles in particle set using a
COO edge list, and provides methods to convert representation.
Attributes
----------
edges : ndarray
COO edge list.
nodes : ndarray
Vertex ids of each particle with at least one edge.
weights : ndarray
Scalar value embedded on each edge.
matrix : ndarray
Adjacency matrix representation.
"""

_data: np.ndarray = array_field("edge")
weights: np.ndarray = array_field("double")

Expand All @@ -636,6 +724,28 @@ def __getitem__(self, key):
key = key.data
return self.__class__(np.array(self._data[key]))

def __add__(self, other_array: "AdjacencyList") -> "AdjacencyList":
    """Joins two AdjacencyList objects, by appending the edges and
    weights of ``other_array`` onto those of this instance.

    An edge which is present in both operands will be repeated in
    the result, giving multigraph connectivity.

    Raises
    ------
    ValueError
        If ``other_array`` is not an instance of this class, or if
        exactly one of the two operands has weights.
    """
    cls = self.__class__
    if not isinstance(other_array, cls):
        raise ValueError("Can only add AdjacencyList.")
    # Either both operands carry weights, or neither does; a mixture
    # would leave the combined weights shorter than the edge list.
    self_weighted = len(self.weights) > 0
    other_weighted = len(other_array.weights) > 0
    if self_weighted != other_weighted:
        raise ValueError(
            "Mismatch between weights: both adjacency lists "
            "must either be weighted, or unweighted."
        )
    merged_edges = np.concatenate([self._data, other_array._data])
    merged_weights = np.concatenate([self.weights, other_array.weights])
    return cls(data=merged_edges, weights=merged_weights)

def copy(self):
    """Produces an independent AdjacencyList, deep copied from this
    instance.
    """
    duplicate = deepcopy(self)
    return duplicate

Expand Down Expand Up @@ -694,6 +804,9 @@ def to_dicts(
edge_data: Optional[Dict[str, ArrayBase]] = None,
node_data: Optional[Dict[str, ArrayBase]] = None,
) -> _AdjDict:
"""Returns data in dictionary format, which is more easily
parsed by external libraries, such as NetworkX.
"""
if edge_data is None:
edge_data = dict()
if node_data is None:
Expand Down Expand Up @@ -726,6 +839,39 @@ def make_data_dicts(orig: Tuple[Any, ...], data: Dict[str, ArrayBase]):
#####################################################
@define
class Graphicle:
"""Composite object, combining particle set data with relational
information between particles.
Attributes
----------
particles : ParticleSet
Data describing the particles in the set.
adj : AdjacencyList
Connectivity between the particles, to form a graph.
pdg : PdgArray
PDG codes.
pmu : MomentumArray
Four momenta.
color : ColorArray
Color / anti-color pairs.
helicity : HelicityArray
Helicity values.
status : StatusArray
Status codes from Monte-Carlo event generator.
final : MaskArray
Boolean array indicating final state in particle set.
edges : ndarray
COO edge list.
nodes : ndarray
Vertex ids of each particle with at least one edge.
hard_mask : MaskGroup
Identifies which particles participate in the hard process.
For Pythia, this is split into four categories: incoming,
intermediate, outgoing, outgoing_nonperturbative_diffraction.
hard_vertex : int
Vertex at which the hard process is initiated.
"""

particles: ParticleSet = ParticleSet()
adj: AdjacencyList = AdjacencyList()

Expand Down Expand Up @@ -755,6 +901,35 @@ def from_numpy(
edges: Optional[np.ndarray] = None,
weights: Optional[np.ndarray] = None,
):
"""Instantiates a Graphicle object from an optional collection
of numpy arrays.
Parameters
----------
pdg : ndarray, optional
PDG codes.
pmu : ndarray, optional
Four momenta, formatted in columns of (x, y, z, e), or as
a structured array with those fields.
color : ndarray, optional
Color / anti-color pairs, formatted in columns of
(col, acol), or as a structured array with those fields.
helicity : ndarray, optional
Helicity values.
status : ndarray, optional
Status codes from Monte-Carlo event generator.
final : ndarray, optional
Boolean array indicating which particles are final state.
edges : ndarray, optional
COO formatted pairs of vertex ids, of shape (N, 2), where
N is the number of particles in the graph.
Alternatively, supply a structured array with fields
(in, out).
weights : ndarray, optional
Weights to be associated with each edge in the COO edge
list, provided in the same order.
"""

particles = ParticleSet.from_numpy(
pdg=pdg,
pmu=pmu,
Expand Down

0 comments on commit 8c4a2f2

Please sign in to comment.