Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Cache creation of compound masks #4612

Open
wants to merge 1 commit into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 46 additions & 35 deletions package/MDAnalysis/core/groups.py
Original file line number Diff line number Diff line change
Expand Up @@ -936,45 +936,56 @@ def _split_by_compound_indices(self, compound, stable_sort=False):
n_compounds : int
The number of individual compounds.
"""
# Caching would help here, especially when repeating the operation
# over different frames, since these masks are coordinate-independent.
# However, cache must be invalidated whenever new compound indices are
# modified, which is not yet implemented.
# Also, should we include here the grouping for 'group', which is
# Should we include here the grouping for 'group', which is
# essentially a non-split?

cache_key = f"{compound}_masks"
compound_indices = self._get_compound_indices(compound)
compound_sizes = np.bincount(compound_indices)
size_per_atom = compound_sizes[compound_indices]
compound_sizes = compound_sizes[compound_sizes != 0]
unique_compound_sizes = unique_int_1d(compound_sizes)

# Are we already sorted? argsorting and fancy-indexing can be expensive
# so we do a quick pre-check.
needs_sorting = np.any(np.diff(compound_indices) < 0)
if needs_sorting:
# stable sort ensures reproducibility, especially concerning who
# gets to be a compound's atom[0] and be a reference for unwrap.
if stable_sort:
sort_indices = np.argsort(compound_indices, kind='stable')
else:
# Quicksort
sort_indices = np.argsort(compound_indices)
# We must sort size_per_atom accordingly (Issue #3352).
size_per_atom = size_per_atom[sort_indices]

compound_masks = []
atom_masks = []
for compound_size in unique_compound_sizes:
compound_masks.append(compound_sizes == compound_size)

# create new cache or invalidate cache when compound indices changed
if (
cache_key not in self._cache
or np.all(self._cache[cache_key]["compound_indices"]
!= compound_indices)):
compound_sizes = np.bincount(compound_indices)
size_per_atom = compound_sizes[compound_indices]
compound_sizes = compound_sizes[compound_sizes != 0]
unique_compound_sizes = unique_int_1d(compound_sizes)

# Are we already sorted? argsorting and fancy-indexing can be
# expensive so we do a quick pre-check.
needs_sorting = np.any(np.diff(compound_indices) < 0)
if needs_sorting:
atom_masks.append(sort_indices[size_per_atom == compound_size]
.reshape(-1, compound_size))
else:
atom_masks.append(np.where(size_per_atom == compound_size)[0]
.reshape(-1, compound_size))
# stable sort ensures reproducibility, especially concerning
# who gets to be a compound's atom[0] and be a reference for
# unwrap.
if stable_sort:
sort_indices = np.argsort(compound_indices, kind='stable')
else:
# Quicksort
sort_indices = np.argsort(compound_indices)
# We must sort size_per_atom accordingly (Issue #3352).
size_per_atom = size_per_atom[sort_indices]

compound_masks = []
atom_masks = []
for compound_size in unique_compound_sizes:
compound_masks.append(compound_sizes == compound_size)
if needs_sorting:
atom_masks.append(sort_indices[size_per_atom
== compound_size]
.reshape(-1, compound_size))
else:
atom_masks.append(np.where(size_per_atom
== compound_size)[0]
.reshape(-1, compound_size))

self._cache[cache_key] = {
"compound_indices": compound_indices,
"data": (atom_masks, compound_masks, len(compound_sizes))
}
Comment on lines +983 to +986
Copy link
Contributor Author

@PicoCentauri PicoCentauri Jun 5, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am not sure whether this is the canonical approach for caching, but I think the `@cached` decorator does not work here.


return atom_masks, compound_masks, len(compound_sizes)
return self._cache[cache_key]["data"]

@warn_if_not_unique
@_pbc_to_wrap
Expand Down Expand Up @@ -3200,7 +3211,7 @@ def select_atoms(self, sel, *othersel, periodic=True, rtol=1e-05,
universe = mda.Universe(PSF, DCD)
guessed_elements = guess_types(universe.atoms.names)
universe.add_TopologyAttr('elements', guessed_elements)

.. doctest:: AtomGroup.select_atoms.smarts

>>> universe.select_atoms("smarts C", smarts_kwargs={"maxMatches": 100})
Expand Down
38 changes: 37 additions & 1 deletion testsuite/MDAnalysisTests/core/test_accumulate.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
# J. Comput. Chem. 32 (2011), 2319--2327, doi:10.1002/jcc.21787
#
import numpy as np
from numpy.testing import assert_equal, assert_almost_equal
from numpy.testing import assert_equal, assert_almost_equal, assert_allclose

import MDAnalysis as mda
from MDAnalysis.exceptions import DuplicateWarning, NoDataError
Expand Down Expand Up @@ -291,3 +291,39 @@ def test_quadrupole_moment_fragments(self, group):
assert_almost_equal(quadrupoles,
np.array([0., 0.0011629, 0.1182701, 0.6891748
])) and len(quadrupoles) == n_compounds


class TestCache:
    """Tests for the compound-mask cache used by ``accumulate``.

    The masks are looked up in the group's ``_cache`` dictionary under the
    key ``f"{compound}_masks"``.
    """

    @pytest.fixture()
    def group(self):
        """Return all atoms of a freshly loaded PSF/DCD universe."""
        return mda.Universe(PSF, DCD).atoms

    def test_cache(self, group):
        """Test that one cache per compound is created and reused."""
        group_nocache = group.copy()
        group_cache = group.copy()

        for compound in ['residues', 'fragments']:
            cache_key = f"{compound}_masks"

            # The first call populates the cache entry for this compound ...
            group_cache.accumulate("masses", compound=compound)
            assert cache_key in group_cache._cache
            # ... so the second call is served from the cached masks.
            actual = group_cache.accumulate("masses", compound=compound)

            # Reference value: make sure no cache entry exists beforehand.
            group_nocache._cache.pop(cache_key, None)
            desired = group_nocache.accumulate("masses", compound=compound)

            assert_allclose(actual, desired)

    @pytest.mark.parametrize("compound",
                             ['residues', 'fragments'])
    def test_cache_updating(self, group, compound):
        """Test caching of compound_masks for updating atomgroups."""
        kwargs = {"attribute": "masses", "compound": compound}
        cache_key = f"{compound}_masks"

        group_nocache = group.select_atoms("prop z < 1.0", updating=True)
        group_cache = group.select_atoms("prop z < 1.0", updating=True)

        assert_allclose(group_nocache.accumulate(**kwargs),
                        group_cache.accumulate(**kwargs))

        # Clear cache and forward to next frame; ``group_cache`` keeps its
        # entry, so it must notice on its own if the selection changed.
        group_nocache._cache.pop(cache_key)
        group.universe.trajectory.next()
        assert_allclose(group_nocache.accumulate(**kwargs),
                        group_cache.accumulate(**kwargs))
Loading