File: results.py

package info (click to toggle)
mdanalysis 2.10.0-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 116,696 kB
  • sloc: python: 92,135; ansic: 8,156; makefile: 215; sh: 138
file content (330 lines) | stat: -rw-r--r-- 9,870 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
"""Analysis results and their aggregation --- :mod:`MDAnalysis.analysis.results`
================================================================================

This module introduces two classes, :class:`Results` and :class:`ResultsGroup`,
used for storing and aggregating data, respectively, in
:meth:`MDAnalysis.analysis.base.AnalysisBase.run()`.


Classes
-------

The :class:`Results` class is an extension of the built-in dictionary
type that holds all assigned attributes in :attr:`self.data` and
allows access via either dict-like or attribute-like syntax:

.. code-block:: python

    from MDAnalysis.analysis.results import Results
    r = Results()
    r.array = [1, 2, 3, 4]
    assert r['array'] == r.array == [1, 2, 3, 4]


The :class:`ResultsGroup` can merge multiple :class:`Results` objects.
It is mainly used by the :class:`MDAnalysis.analysis.base.AnalysisBase` class,
which uses the :meth:`ResultsGroup.merge()` method to aggregate results from
multiple workers initialized during a parallel run:

.. code-block:: python

    from MDAnalysis.analysis.results import Results, ResultsGroup
    import numpy as np
    
    r1, r2 = Results(), Results()
    r1.masses = [1, 2, 3, 4, 5]
    r2.masses = [0, 0, 0, 0]
    r1.vectors = np.arange(10).reshape(5, 2)
    r2.vectors = np.arange(8).reshape(4, 2)

    group = ResultsGroup(
        lookup = {
            'masses': ResultsGroup.flatten_sequence,
            'vectors': ResultsGroup.ndarray_vstack
            }
        )

    r = group.merge([r1, r2])
    assert r.masses == list((*r1.masses, *r2.masses))
    assert (r.vectors == np.vstack([r1.vectors, r2.vectors])).all()
"""

from collections import UserDict
import numpy as np
from typing import Callable, Sequence


class Results(UserDict):
    r"""Dictionary-like container whose entries are also attributes.

    Every value stored in a :class:`Results` instance can be reached in two
    equivalent ways: with dictionary syntax, ``results["value_key"]``, or with
    attribute syntax, ``results.value_key``. Assigning through either syntax
    stores the value in the same underlying mapping (``self.data``).
    :class:`Results` stores all results obtained from an analysis after
    calling :meth:`~AnalysisBase.run()`.

    The implementation is similar to the :class:`sklearn.utils.Bunch`
    class in `scikit-learn`_.

    .. _`scikit-learn`: https://scikit-learn.org/
    .. _`sklearn.utils.Bunch`: https://scikit-learn.org/stable/modules/generated/sklearn.utils.Bunch.html

    Raises
    ------
    AttributeError
        If an assigned key has the same name as an attribute that already
        exists on the object (anything listed by ``dir()``, including the
        internal ``data`` mapping), or if a missing attribute is read or
        deleted.

    ValueError
        If a key is a ``str`` that is not a valid Python identifier and
        therefore could not be accessed by attribute. Keys that are not of
        type ``str`` are not checked for this.

    Examples
    --------
    >>> from MDAnalysis.analysis.results import Results
    >>> results = Results(a=1, b=2)
    >>> results['b']
    2
    >>> results.b
    2
    >>> results.a = 3
    >>> results['a']
    3
    >>> results.c = [1, 2, 3, 4]
    >>> results['c']
    [1, 2, 3, 4]


    .. versionadded:: 2.0.0

    .. versionchanged:: 2.8.0
        Moved :class:`Results` to :mod:`MDAnalysis.analysis.results`
    """

    def _validate_key(self, key):
        """Raise if `key` must not be stored (see class-level ``Raises``)."""
        # dir() enumerates every name already usable as an attribute
        # (methods, dunder names, and the instance-level ``data`` entry);
        # a key with such a name would be shadowed on attribute access.
        if key in dir(self):
            raise AttributeError(
                f"'{key}' is a protected dictionary attribute"
            )
        # Only string keys are checked for identifier syntax.
        if not isinstance(key, str):
            return
        if not key.isidentifier():
            raise ValueError(f"'{key}' is not a valid attribute")

    def __init__(self, *args, **kwargs):
        # Normalise positional and keyword input exactly as ``dict`` would.
        initial = dict(*args, **kwargs)
        if "data" in initial:
            raise AttributeError("'data' is a protected dictionary attribute")
        # Create the backing store directly in the instance namespace so it
        # exists before any item is assigned.
        self.__dict__["data"] = {}
        self.update(initial)

    def __setitem__(self, key, item):
        # Validate first so that a rejected key is never stored.
        self._validate_key(key)
        super().__setitem__(key, item)

    def __setattr__(self, attr, val):
        # Everything except the backing store itself becomes a dict entry.
        if attr != "data":
            self[attr] = val
            return
        super().__setattr__(attr, val)

    def __getattr__(self, attr):
        # Only called when normal attribute lookup fails; fall back to the
        # stored items and translate a missing key into AttributeError.
        try:
            value = self[attr]
        except KeyError as err:
            raise AttributeError(
                f"'Results' object has no attribute '{attr}'"
            ) from err
        return value

    def __delattr__(self, attr):
        # Deleting an attribute removes the corresponding dict entry.
        try:
            del self[attr]
        except KeyError as err:
            raise AttributeError(
                f"'Results' object has no attribute '{attr}'"
            ) from err

    def __getstate__(self):
        # The stored items are the complete pickled state.
        return self.data

    def __setstate__(self, state):
        # Restore the backing store as-is (handled by ``__setattr__``).
        self.data = state


class ResultsGroup:
    """
    Management and aggregation of results stored in :class:`Results` instances.

    A :class:`ResultsGroup` is an optional description for :class:`Results`
    "dictionaries" that are used in analysis classes based on
    :class:`AnalysisBase`. For each *key* in a :class:`Results` it describes
    how multiple pieces of the data held under the key are to be aggregated.
    This approach is necessary when parts of a trajectory are analyzed
    independently (e.g., in parallel) and then need to be merged (with
    :meth:`merge`) to obtain a complete data set.

    Parameters
    ----------
    lookup : dict[str, Callable], optional
        aggregation functions lookup dict, by default None. Each value is
        called with a list holding the per-object values of its key and must
        return the aggregated value. A value of ``None`` marks a key as
        having no aggregator. ``lookup=None`` is treated by :meth:`merge`
        like an empty dict.

    Examples
    --------

    .. code-block:: python

        from MDAnalysis.analysis.results import ResultsGroup, Results
        group = ResultsGroup(lookup={'mass': ResultsGroup.float_mean})
        obj1 = Results(mass=1)
        obj2 = Results(mass=3)
        assert {'mass': 2.0} == group.merge([obj1, obj2])


    .. code-block:: python

        # you can also set `None` for those attributes that you want to skip
        lookup = {'mass': ResultsGroup.float_mean, 'skip': None}
        group = ResultsGroup(lookup)
        objects = [Results(mass=1, skip=None), Results(mass=3, skip=object)]
        assert group.merge(objects, require_all_aggregators=False) == {'mass': 2.0}

    .. versionadded:: 2.8.0
    """

    def __init__(self, lookup: dict[str, Callable] = None):
        # Stored unchanged; ``merge`` copes with the ``None`` default.
        self._lookup = lookup

    def merge(
        self, objects: Sequence[Results], require_all_aggregators: bool = True
    ) -> Results:
        """Merge multiple Results into a single Results instance.

        Merge multiple :class:`Results` instances into a single one, using the
        `lookup` dictionary to determine the appropriate aggregator functions
        for each named results attribute. If `objects` contains exactly one
        element, that element itself is returned without using any
        aggregators.

        The keys to merge are taken from the first element of `objects`; for
        each of them the values of *all* objects are collected and passed to
        the aggregator.

        Parameters
        ----------
        objects : Sequence[Results]
            Multiple :class:`Results` instances with the same data attributes.
            Must not be empty (the first element is always accessed).
        require_all_aggregators : bool, optional
            if True, raise an exception when no aggregation function for a
            particular argument is found. Allows to skip aggregation for the
            parameters that aren't needed in the final object --
            see :class:`ResultsGroup`.

        Returns
        -------
        Results
            merged :class:`Results`

        Raises
        ------
        ValueError
            if no aggregation function for a key is found (the key is absent
            from the lookup, mapped to ``None``, or no lookup was given at
            all) and ``require_all_aggregators=True``
        """
        if len(objects) == 1:
            return objects[0]

        # A group built with the default ``lookup=None`` has no aggregators;
        # treat it as an empty mapping so that missing aggregators are
        # reported (or skipped) through the regular code path below instead
        # of failing with an AttributeError on ``None.get``.
        lookup = self._lookup if self._lookup is not None else {}

        merged_results = Results()
        for key in objects[0]:
            agg_function = lookup.get(key)
            if agg_function is not None:
                # Gather this key's value from every partial result.
                results_of_t = [obj[key] for obj in objects]
                merged_results[key] = agg_function(results_of_t)
            elif require_all_aggregators:
                raise ValueError(f"No aggregation function for {key=}")
        return merged_results

    @staticmethod
    def flatten_sequence(arrs: list[list]) -> list:
        """Flatten a list of lists into a list

        Parameters
        ----------
        arrs : list[list]
            list of lists

        Returns
        -------
        list
            flattened list
        """
        return [item for sublist in arrs for item in sublist]

    @staticmethod
    def ndarray_sum(arrs: list[np.ndarray]) -> np.ndarray:
        """sums the input ndarrays along ``axis=0``

        Parameters
        ----------
        arrs : list[np.ndarray]
            list of input arrays. Must have the same shape.

        Returns
        -------
        np.ndarray
            sum of input arrays
        """
        return np.array(arrs).sum(axis=0)

    @staticmethod
    def ndarray_mean(arrs: list[np.ndarray]) -> np.ndarray:
        """calculates mean of input ndarrays along ``axis=0``

        Parameters
        ----------
        arrs : list[np.ndarray]
            list of input arrays. Must have the same shape.

        Returns
        -------
        np.ndarray
            mean of input arrays
        """
        return np.array(arrs).mean(axis=0)

    @staticmethod
    def float_mean(floats: list[float]) -> float:
        """calculates mean of input float values

        Parameters
        ----------
        floats : list[float]
            list of float values

        Returns
        -------
        float
            mean value
        """
        return np.array(floats).mean()

    @staticmethod
    def ndarray_hstack(arrs: list[np.ndarray]) -> np.ndarray:
        """Performs horizontal stack of input arrays

        Parameters
        ----------
        arrs : list[np.ndarray]
            input numpy arrays

        Returns
        -------
        np.ndarray
            result of stacking
        """
        return np.hstack(arrs)

    @staticmethod
    def ndarray_vstack(arrs: list[np.ndarray]) -> np.ndarray:
        """Performs vertical stack of input arrays

        Parameters
        ----------
        arrs : list[np.ndarray]
            input numpy arrays

        Returns
        -------
        np.ndarray
            result of stacking
        """
        return np.vstack(arrs)