Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Limit numpy thread usage for Transformation classes #2950

Merged
merged 33 commits into from
Apr 10, 2021
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
Show all changes
33 commits
Select commit Hold shift + click to select a range
b8441ff
add threadlimit deco
Sep 19, 2020
604fa1b
add threadlimit deco
yuxuanzhuang Sep 19, 2020
0ed46f9
add dep
yuxuanzhuang Sep 19, 2020
b41a739
runtime change
yuxuanzhuang Sep 23, 2020
3941281
context mananger instead of decor
yuxuanzhuang Oct 13, 2020
326e105
working decor
yuxuanzhuang Oct 16, 2020
b4daa8d
create TransformationBase and thread limit
yuxuanzhuang Oct 16, 2020
d686f06
travis threadpool
yuxuanzhuang Oct 16, 2020
af6bab8
deco to context due to picklibility
yuxuanzhuang Oct 16, 2020
5f985e9
pep
yuxuanzhuang Oct 16, 2020
efbc3f8
docs
yuxuanzhuang Nov 4, 2020
ded5f84
Merge remote-tracking branch 'mda_origin/develop' into trans_single_t…
yuxuanzhuang Nov 4, 2020
e2cd996
change to kwargs
yuxuanzhuang Nov 4, 2020
f084d33
add test for transformation base
yuxuanzhuang Nov 4, 2020
67aeda2
changelog
yuxuanzhuang Nov 4, 2020
d363c0b
appveyor
yuxuanzhuang Nov 4, 2020
34d370e
travis threadpool for arm
yuxuanzhuang Nov 4, 2020
01570c4
remove deco threadlimit
yuxuanzhuang Nov 4, 2020
383b1f5
doc for transformation
yuxuanzhuang Nov 9, 2020
317e5e2
merge to dev
yuxuanzhuang Nov 9, 2020
0fb4d85
typo
yuxuanzhuang Nov 9, 2020
9365e5f
merge to develop
yuxuanzhuang Apr 6, 2021
353bfa9
make threadpoolctl a requirement
yuxuanzhuang Apr 7, 2021
fdd8f52
base documentation
yuxuanzhuang Apr 7, 2021
edb5ef5
doc for transformation fix
yuxuanzhuang Apr 7, 2021
7884210
box dimension rework
yuxuanzhuang Apr 7, 2021
636ae12
add threadpoolctl to azure
yuxuanzhuang Apr 7, 2021
1ea51df
add threadpoolctl to gh ci
yuxuanzhuang Apr 7, 2021
197deb9
merge to develop
yuxuanzhuang Apr 7, 2021
e2ca135
add threadpoolctl to ppc64le
yuxuanzhuang Apr 8, 2021
aba771f
cov notimplement error transformation
yuxuanzhuang Apr 8, 2021
890e2b3
add note for maxthread=1
yuxuanzhuang Apr 8, 2021
17d4cb4
changelog for threadpooltcl
yuxuanzhuang Apr 8, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ env:
- MAIN_CMD="pytest ${PYTEST_LIST}"
- SETUP_CMD="${PYTEST_FLAGS}"
- BUILD_CMD="pip install -e package/ && (cd testsuite/ && python setup.py build)"
- CONDA_MIN_DEPENDENCIES="mmtf-python biopython networkx cython matplotlib scipy griddataformats hypothesis gsd codecov"
- CONDA_MIN_DEPENDENCIES="mmtf-python biopython networkx cython matplotlib scipy griddataformats hypothesis gsd codecov threadpoolctl"
- CONDA_DEPENDENCIES="${CONDA_MIN_DEPENDENCIES} seaborn>=0.7.0 clustalw=2.1 netcdf4 scikit-learn joblib>=0.12 chemfiles tqdm>=4.43.0 tidynamics>=1.0.0 rdkit>=2020.03.1 h5py"
- CONDA_CHANNELS='biobuilds conda-forge'
- CONDA_CHANNEL_PRIORITY=True
Expand Down
23 changes: 23 additions & 0 deletions package/MDAnalysis/lib/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,8 @@
import functools
from functools import wraps
import textwrap
from contextlib import ContextDecorator
from threadpoolctl import threadpool_limits

import mmtf
import numpy as np
Expand Down Expand Up @@ -2353,3 +2355,24 @@ def check_box(box):
if np.all(box[3:] == 90.):
return 'ortho', box[:3]
return 'tri_vecs', triclinic_vectors(box)


class threadpool_limits_decorator(threadpool_limits, ContextDecorator):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm probably missing something very obvious, but is this being used anywhere anymore?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, not really. It turns out to be not that useful in this case.

def __init__(self, limits=None, user_api=None):
    """Validate and store the requested limits without applying them.

    Parameters
    ----------
    limits : optional
        Passed unchanged to ``_check_params``.
    user_api : optional
        Passed unchanged to ``_check_params``.

    Notes
    -----
    Only the checked parameters are recorded here; the limits themselves
    are applied in ``__enter__``.
    """
    checked = self._check_params(limits, user_api)
    self._limits, self._user_api, self._prefixes = checked

def __enter__(self):
    """Apply the stored limits and remember the previous thread counts.

    Returns
    -------
    self
        So the instance can be bound with ``with ... as``.
    """
    applied = self._set_threadpool_limits()
    self._original_info = applied
    # Must be read after ``_original_info`` is set; ``unregister`` uses
    # this mapping to restore each module's thread count.
    self.origin_num_threads = self.get_original_num_threads()
    return self

def __exit__(self, *exc):
    """Restore the saved thread counts when leaving the ``with`` block.

    The exception details in ``exc`` are ignored and ``None`` is returned,
    so an exception raised inside the block is not suppressed.
    """
    self.unregister()
    return None

def unregister(self):
    """Restore every controlled module to its saved number of threads.

    For each entry of ``self._original_info`` the thread count stored in
    ``self.origin_num_threads`` under the entry's ``user_api`` key is
    re-applied through ``set_num_threads``.

    Calling this before the context has been entered is a no-op:
    ``__init__`` does not set ``_original_info``, so a missing attribute is
    treated the same as ``None`` (nothing to restore) instead of raising
    ``AttributeError``.
    """
    original_info = getattr(self, "_original_info", None)
    if original_info is None:
        return

    saved_threads = self.origin_num_threads
    for module in original_info:
        module.set_num_threads(saved_threads[module.user_api])
1 change: 1 addition & 0 deletions package/MDAnalysis/transformations/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,7 @@ def wrapped(ts):
method instead of being written as a function/closure.
"""

from .base import TransformationBase
from .translate import translate, center_in_box
from .rotate import rotateby
from .positionaveraging import PositionAverager
Expand Down
42 changes: 42 additions & 0 deletions package/MDAnalysis/transformations/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
# -*- Mode: python; tab-width: 4; indent-tabs-mode:nil; coding:utf-8 -*-
orbeckst marked this conversation as resolved.
Show resolved Hide resolved
# vim: tabstop=4 expandtab shiftwidth=4 softtabstop=4
#
# MDAnalysis --- https://www.mdanalysis.org
# Copyright (c) 2006-2017 The MDAnalysis Development Team and contributors
# (see the file AUTHORS for the full list of names)
#
# Released under the GNU Public Licence, v2 or any higher version
#
# Please cite your use of MDAnalysis in published work:
#
# R. J. Gowers, M. Linke, J. Barnoud, T. J. E. Reddy, M. N. Melo, S. L. Seyler,
# D. L. Dotson, J. Domanski, S. Buchoux, I. M. Kenney, and O. Beckstein.
# MDAnalysis: A Python package for the rapid analysis of molecular dynamics
# simulations. In S. Benthall and S. Rostrup editors, Proceedings of the 15th
# Python in Science Conference, pages 102-109, Austin, TX, 2016. SciPy.
#
# N. Michaud-Agrawal, E. J. Denning, T. B. Woolf, and O. Beckstein.
# MDAnalysis: A Toolkit for the Analysis of Molecular Dynamics Simulations.
# J. Comput. Chem. 32 (2011), 2319--2327, doi:10.1002/jcc.21787
#

"""\
Transformations Base Class --- :mod:`MDAnalysis.transformations.base`
=====================================================================

.. autoclass:: TransformationBase

"""
from threadpoolctl import threadpool_limits
orbeckst marked this conversation as resolved.
Show resolved Hide resolved


class TransformationBase(object):
IAlibay marked this conversation as resolved.
Show resolved Hide resolved
def __init__(self, max_threads):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you give this a default argument, e.g. max_threads=None, or even make it a default class attribute? My concern is that users who don't understand or care about parallelisation will try to make custom transformations and override __init__ or forget to pass it an argument, with predictable errors upon initialisation or __call__ing.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree with @lilyminium. Most users will not care about the number of threads used, so there should be a default value that 1) lets users focus on what they care about and 2) keeps the existing code working.

self.max_threads = max_threads

def __call__(self, ts):
    """Run ``_transform`` on ``ts`` inside a ``threadpool_limits`` context.

    Parameters
    ----------
    ts
        The object handed on unchanged to ``self._transform``.

    Returns
    -------
    Whatever ``self._transform(ts)`` returns.
    """
    limiter = threadpool_limits(self.max_threads)
    with limiter:
        result = self._transform(ts)
    return result

def _transform(self, ts):
    """Transformation hook to be overridden by subclasses.

    ``__call__`` invokes this method as ``self._transform(ts)``, so the
    hook has to accept the timestep argument.

    Parameters
    ----------
    ts
        The object passed to ``__call__``.

    Raises
    ------
    NotImplementedError
        Always, in this base implementation: failing loudly is preferable
        to silently returning ``None`` in place of the transformed
        timestep when a subclass forgets to override the hook.
    """
    raise NotImplementedError("Only implemented in child classes")
18 changes: 12 additions & 6 deletions package/MDAnalysis/transformations/fit.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,10 @@
from ..analysis import align
from ..lib.transformations import euler_from_matrix, euler_matrix

from .base import TransformationBase

class fit_translation(object):

class fit_translation(TransformationBase):
"""Translates a given AtomGroup so that its center of geometry/mass matches
the respective center of the given reference. A plane can be given by the
user using the option `plane`, and will result in the removal of
Expand Down Expand Up @@ -85,7 +87,9 @@ class fit_translation(object):
The transformation was changed from a function/closure to a class
with ``__call__``.
"""
def __init__(self, ag, reference, plane=None, weights=None):
def __init__(self, ag, reference, plane=None, weights=None, max_threads=1):
super().__init__(max_threads)

self.ag = ag
self.reference = reference
self.plane = plane
Expand Down Expand Up @@ -117,7 +121,7 @@ def __init__(self, ag, reference, plane=None, weights=None):
self.weights = align.get_weights(self.ref.atoms, weights=self.weights)
self.ref_com = self.ref.center(self.weights)

def __call__(self, ts):
def _transform(self, ts):
mobile_com = np.asarray(self.mobile.atoms.center(self.weights),
np.float32)
vector = self.ref_com - mobile_com
Expand All @@ -128,7 +132,7 @@ def __call__(self, ts):
return ts


class fit_rot_trans(object):
class fit_rot_trans(TransformationBase):
"""Perform a spatial superposition by minimizing the RMSD.

Spatially align the group of atoms `ag` to `reference` by doing a RMSD
Expand Down Expand Up @@ -175,7 +179,9 @@ class fit_rot_trans(object):
-------
MDAnalysis.coordinates.base.Timestep
"""
def __init__(self, ag, reference, plane=None, weights=None):
def __init__(self, ag, reference, plane=None, weights=None, max_threads=1):
super().__init__(max_threads)

self.ag = ag
self.reference = reference
self.plane = plane
Expand Down Expand Up @@ -207,7 +213,7 @@ def __init__(self, ag, reference, plane=None, weights=None):
self.ref_com = self.ref.center(self.weights)
self.ref_coordinates = self.ref.atoms.positions - self.ref_com

def __call__(self, ts):
def _transform(self, ts):
mobile_com = self.mobile.atoms.center(self.weights)
mobile_coordinates = self.mobile.atoms.positions - mobile_com
rotation, dump = align.rotation_matrix(mobile_coordinates,
Expand Down
9 changes: 6 additions & 3 deletions package/MDAnalysis/transformations/positionaveraging.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,10 @@
import numpy as np
import warnings

from .base import TransformationBase

class PositionAverager(object):

class PositionAverager(TransformationBase):
"""
Averages the coordinates of a given timestep so that the coordinates
of the AtomGroup correspond to the average positions of the N previous
Expand Down Expand Up @@ -136,7 +138,8 @@ class PositionAverager(object):

"""

def __init__(self, avg_frames, check_reset=True):
def __init__(self, avg_frames, check_reset=True, max_threads=1):
super().__init__(max_threads)
self.avg_frames = avg_frames
self.check_reset = check_reset
self.current_avg = 0
Expand All @@ -162,7 +165,7 @@ def rollposx(self, ts):
self.coord_array = np.roll(self.coord_array, 1, axis=2)
self.coord_array[..., 0] = ts.positions.copy()

def __call__(self, ts):
def _transform(self, ts):
# calling the same timestep will not add new data to coord_array
# This can prevent from getting different values when
# call `u.trajectory[i]` multiple times.
Expand Down
11 changes: 8 additions & 3 deletions package/MDAnalysis/transformations/rotate.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,10 @@
from ..lib.transformations import rotation_matrix
from ..lib.util import get_weights

from .base import TransformationBase

class rotateby(object):

class rotateby(TransformationBase):
'''
Rotates the trajectory by a given angle on a given axis. The axis is defined by
the user, combining the direction vector and a point. This point can be the center
Expand Down Expand Up @@ -118,7 +120,10 @@ def __init__(self,
point=None,
ag=None,
weights=None,
wrap=False):
wrap=False,
max_threads=1):
super().__init__(max_threads)

self.angle = angle
self.direction = direction
self.point = point
Expand Down Expand Up @@ -162,7 +167,7 @@ def __init__(self,
else:
raise ValueError('A point or an AtomGroup must be specified')

def __call__(self, ts):
def _transform(self, ts):
if self.point is None:
position = self.center_method()
else:
Expand Down
19 changes: 13 additions & 6 deletions package/MDAnalysis/transformations/translate.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,10 @@
import numpy as np
from functools import partial

from .base import TransformationBase

class translate(object):

class translate(TransformationBase):
"""
Translates the coordinates of a given :class:`~MDAnalysis.coordinates.base.Timestep`
instance by a given vector.
Expand All @@ -59,20 +61,22 @@ class translate(object):
:class:`~MDAnalysis.coordinates.base.Timestep` object

"""
def __init__(self, vector):
def __init__(self, vector, max_threads=1):
super().__init__(max_threads)

self.vector = vector

if len(self.vector) > 2:
self.vector = np.float32(self.vector)
else:
raise ValueError("{} vector is too short".format(self.vector))

def __call__(self, ts):
def _transform(self, ts):
ts.positions += self.vector
return ts


class center_in_box(object):
class center_in_box(TransformationBase):
"""
Translates the coordinates of a given :class:`~MDAnalysis.coordinates.base.Timestep`
instance so that the center of geometry/mass of the given :class:`~MDAnalysis.core.groups.AtomGroup`
Expand Down Expand Up @@ -112,7 +116,10 @@ class center_in_box(object):
The transformation was changed from a function/closure to a class
with ``__call__``.
"""
def __init__(self, ag, center='geometry', point=None, wrap=False):
def __init__(self, ag, center='geometry', point=None, wrap=False,
max_threads=1):
super().__init__(max_threads)

self.ag = ag
self.center = center
self.point = point
Expand Down Expand Up @@ -140,7 +147,7 @@ def __init__(self, ag, center='geometry', point=None, wrap=False):
raise ValueError(f'{self.ag} is not an AtomGroup object') \
from None

def __call__(self, ts):
def _transform(self, ts):
if self.point is None:
boxcenter = np.sum(ts.triclinic_dimensions, axis=0) / 2
else:
Expand Down
18 changes: 12 additions & 6 deletions package/MDAnalysis/transformations/wrap.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,10 @@

from ..lib._cutil import make_whole

from .base import TransformationBase
IAlibay marked this conversation as resolved.
Show resolved Hide resolved

class wrap(object):

class wrap(TransformationBase):
"""
Shift the contents of a given AtomGroup back into the unit cell. ::

Expand Down Expand Up @@ -85,16 +87,18 @@ class wrap(object):
The transformation was changed from a function/closure to a class
with ``__call__``.
"""
def __init__(self, ag, compound='atoms'):
def __init__(self, ag, compound='atoms', max_threads=1):
super().__init__(max_threads)

self.ag = ag
self.compound = compound

def __call__(self, ts):
def _transform(self, ts):
self.ag.wrap(compound=self.compound)
return ts


class unwrap(object):
class unwrap(TransformationBase):
"""
Move all atoms in an AtomGroup so that bonds don't split over images

Expand Down Expand Up @@ -139,15 +143,17 @@ class unwrap(object):
The transformation was changed from a function/closure to a class
with ``__call__``.
"""
def __init__(self, ag):
def __init__(self, ag, max_threads=1):
super().__init__(max_threads)

self.ag = ag

try:
self.ag.fragments
except AttributeError:
raise AttributeError("{} has no fragments".format(self.ag))

def __call__(self, ts):
def _transform(self, ts):
for frag in self.ag.fragments:
make_whole(frag)
return ts