File: backends.py

"""Analysis backends --- :mod:`MDAnalysis.analysis.backends`
============================================================

.. versionadded:: 2.8.0


The :mod:`backends` module provides the :class:`BackendBase` base class for
implementing custom execution backends for
:meth:`MDAnalysis.analysis.base.AnalysisBase.run` and its
subclasses.

.. seealso:: :ref:`parallel-analysis`

.. _backends:

Backends
--------

Three built-in backend classes are provided (a usage sketch follows the list):

* *serial*: :class:`BackendSerial`, which is equivalent to using no
  parallelization and is the default

* *multiprocessing*: :class:`BackendMultiprocessing`, which supports
  parallelization via the standard Python :mod:`multiprocessing` module
  and uses the default :mod:`pickle` serialization

* *dask*: :class:`BackendDask`, which uses the same process-based
  parallelization as :class:`BackendMultiprocessing` but a different
  serialization algorithm via `dask <https://dask.org/>`_ (see `dask
  serialization algorithms
  <https://distributed.dask.org/en/latest/serialization.html>`_ for details)
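
A backend is selected by passing its name (for the built-in backends) or a
backend instance to the ``backend`` keyword of
:meth:`MDAnalysis.analysis.base.AnalysisBase.run`. A minimal sketch, assuming
an analysis class that supports parallel execution (such as
:class:`~MDAnalysis.analysis.rms.RMSD`):

.. code-block:: python

    import MDAnalysis as mda
    from MDAnalysis.tests.datafiles import PSF, DCD
    from MDAnalysis.analysis.rms import RMSD

    u = mda.Universe(PSF, DCD)
    ref = mda.Universe(PSF, DCD)

    R = RMSD(u, ref)

    # run with 2 worker processes via the multiprocessing backend
    R.run(backend="multiprocessing", n_workers=2)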

Classes
-------

"""

import warnings
from typing import Callable
from MDAnalysis.lib.util import is_installed


class BackendBase:
    """Base class for backend implementation.

    Initializes an instance and performs validity checks, such as ensuring
    that ``n_workers`` is a positive integer; subclasses may add further checks.

    Parameters
    ----------
    n_workers : int
        number of workers (usually, processes) over which the work is split

    Examples
    --------
    .. code-block:: python

        from MDAnalysis.analysis.backends import BackendBase

        class ThreadsBackend(BackendBase):
            def apply(self, func, computations):
                from multiprocessing.dummy import Pool

                with Pool(processes=self.n_workers) as pool:
                    results = pool.map(func, computations)
                return results

        import MDAnalysis as mda
        from MDAnalysis.tests.datafiles import PSF, DCD
        from MDAnalysis.analysis.rms import RMSD

        u = mda.Universe(PSF, DCD)
        ref = mda.Universe(PSF, DCD)

        R = RMSD(u, ref)

        n_workers = 2
        backend = ThreadsBackend(n_workers=n_workers)
        R.run(backend=backend, unsupported_backend=True)

    .. warning::
        The ``ThreadsBackend`` above is an educational example only. Running
        it will lead to erroneous results, since worker threads share state
        and trajectory reading is not thread-safe. Do not use it for real
        analysis.


    .. versionadded:: 2.8.0

    """

    def __init__(self, n_workers: int):
        self.n_workers = n_workers
        self._validate()

    def _get_checks(self):
        """Get dictionary with ``condition: error_message`` pairs that ensure the
        validity of the backend instance

        Returns
        -------
        dict
            dictionary with ``condition: error_message`` pairs that will get
            checked during ``_validate()`` run
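
        Examples
        --------
        A subclass can extend the base checks. A minimal sketch (the
        ``GPUBackend`` name and the ``cupy`` requirement are hypothetical,
        shown for illustration only):

        .. code-block:: python

            from MDAnalysis.lib.util import is_installed

            class GPUBackend(BackendBase):
                def _get_checks(self):
                    # merge the base n_workers check with an extra condition
                    return super()._get_checks() | {
                        is_installed("cupy"): "module 'cupy' is missing"
                    }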
        """
        return {
            isinstance(self.n_workers, int) and self.n_workers > 0: (
                f"n_workers should be a positive integer, got {self.n_workers=}"
            ),
        }

    def _get_warnings(self):
        """Get dictionary with ``condition: warning_message`` pairs that ensure
        the good usage of the backend instance

        Returns
        -------
        dict
            dictionary with ``condition: warning_message`` pairs that will get
            checked during ``_validate()`` run
        """
        return dict()

    def _validate(self):
        """Check correctness (e.g. ``dask`` is installed if using ``backend='dask'``)
        and good usage (e.g. ``n_workers=1`` if backend is serial) of the backend

        Raises
        ------
        ValueError
            if one of the conditions in :meth:`_get_checks` evaluates to ``False``
        """
        for check, msg in self._get_checks().items():
            if not check:
                raise ValueError(msg)
        for check, msg in self._get_warnings().items():
            if not check:
                warnings.warn(msg)

    def apply(self, func: Callable, computations: list) -> list:
        """map function `func` to all tasks in the `computations` list

        Main method that will get called when using an instance of
        ``BackendBase``.  It is equivalent to running ``[func(item) for item in
        computations]`` while using the parallel backend capabilities.

        Parameters
        ----------
        func : Callable
            function to be called on each of the tasks in computations list
        computations : list
            computation tasks to apply function to

        Returns
        -------
        list
            list of results of the function

        """
        raise NotImplementedError


class BackendSerial(BackendBase):
    """A built-in backend that does serial execution of the function, without any
    parallelization.

    Parameters
    ----------
    n_workers : int
        Ignored in this class; a warning is issued if ``n_workers`` > 1.
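
    Examples
    --------
    A minimal sketch, for symmetry with the other built-in backends (serial
    execution is also what you get by default, without specifying a backend):

    .. code-block:: python

        from MDAnalysis.analysis.backends import BackendSerial

        backend_obj = BackendSerial(n_workers=1)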


    .. versionadded:: 2.8.0
    """

    def _get_warnings(self):
        """Get dictionary with ``condition: warning_message`` pairs that ensure
        the good usage of the backend instance. Here, it checks if the number
        of workers is not 1, otherwise gives warning.

        Returns
        -------
        dict
            dictionary with ``condition: warning_message`` pairs that will get
            checked during ``_validate()`` run
        """
        return {
            self.n_workers == 1: (
                "n_workers is ignored when executing with backend='serial'"
            )
        }

    def apply(self, func: Callable, computations: list) -> list:
        """
        Serially applies `func` to each task object in ``computations``.

        Parameters
        ----------
        func : Callable
            function to be called on each of the tasks in computations list
        computations : list
            computation tasks to apply function to

        Returns
        -------
        list
            list of results of the function
        """
        return [func(task) for task in computations]


class BackendMultiprocessing(BackendBase):
    """A built-in backend that executes a given function using the
    :meth:`multiprocessing.Pool.map <multiprocessing.pool.Pool.map>` method.

    Parameters
    ----------
    n_workers : int
        number of processes in :class:`multiprocessing.Pool
        <multiprocessing.pool.Pool>` to distribute the workload
        between. Must be a positive integer.

    Examples
    --------

    .. code-block:: python

        from MDAnalysis.analysis.backends import BackendMultiprocessing
        import multiprocessing as mp

        backend_obj = BackendMultiprocessing(n_workers=mp.cpu_count())


    .. versionadded:: 2.8.0

    """

    def apply(self, func: Callable, computations: list) -> list:
        """Applies `func` to each object in ``computations`` using `multiprocessing`'s `Pool.map`.

        Parameters
        ----------
        func : Callable
            function to be called on each of the tasks in computations list
        computations : list
            computation tasks to apply function to

        Returns
        -------
        list
            list of results of the function
        """
        from multiprocessing import Pool

        with Pool(processes=self.n_workers) as pool:
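            # Pool.map blocks until all tasks complete and returns results
            # in the order of the input ``computations``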
            results = pool.map(func, computations)
        return results


class BackendDask(BackendBase):
    """A built-in backend that executes a given function with *dask*.

    Execution is performed by calling :func:`dask.compute` on a list of
    :class:`dask.delayed.Delayed` objects (created with
    :func:`dask.delayed.delayed`), with ``scheduler='processes'`` and
    ``chunksize=1`` (this ensures uniform distribution of tasks among
    processes). Requires the `dask package <https://docs.dask.org/en/stable/>`_
    to be `installed <https://docs.dask.org/en/stable/install.html>`_.

    Parameters
    ----------
    n_workers : int
        number of processes to distribute the workload
        between. Must be a positive integer. Workers are actually
        :class:`multiprocessing.pool.Pool` processes, but they use a different and
        more flexible `serialization protocol
        <https://docs.dask.org/en/stable/phases-of-computation.html#graph-serialization>`_.

    Examples
    --------

    .. code-block:: python

        from MDAnalysis.analysis.backends import BackendDask
        import multiprocessing as mp

        backend_obj = BackendDask(n_workers=mp.cpu_count())


    .. versionadded:: 2.8.0

    """

    def apply(self, func: Callable, computations: list) -> list:
        """Applies `func` to each object in ``computations``.

        Parameters
        ----------
        func : Callable
            function to be called on each of the tasks in computations list
        computations : list
            computation tasks to apply function to

        Returns
        -------
        list
            list of results of the function
        """
        from dask.delayed import delayed
        import dask

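        # wrap each task in a lazy Delayed object; nothing runs yet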
        computations = [delayed(func)(task) for task in computations]
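        # evaluate all Delayed objects at once with the process-based
        # scheduler; chunksize=1 hands tasks to workers one at a time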
        results = dask.compute(
            computations,
            scheduler="processes",
            chunksize=1,
            num_workers=self.n_workers,
        )[0]
        return results

    def _get_checks(self):
        """Get dictionary with ``condition: error_message`` pairs that ensure the
        validity of the backend instance. Here, it additionally checks that
        the ``dask`` module is installed in the environment.

        Returns
        -------
        dict
            dictionary with ``condition: error_message`` pairs that will get
            checked during ``_validate()`` run
        """
        base_checks = super()._get_checks()
        checks = {
            is_installed("dask"): (
                "module 'dask' is missing. Please install 'dask': "
                "https://docs.dask.org/en/stable/install.html"
            )
        }
        return base_checks | checks