File: test_backends.py

package info (click to toggle)
mdanalysis 2.10.0-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 116,696 kB
  • sloc: python: 92,135; ansic: 8,156; makefile: 215; sh: 138
file content (86 lines) | stat: -rw-r--r-- 2,541 bytes | parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
import re

import pytest
from MDAnalysis.analysis import backends
from MDAnalysis.lib.util import is_installed


def square(x: int) -> int:
    """Return ``x`` raised to the second power (a trivial work item for backends)."""
    return pow(x, 2)


def noop(x):
    """Identity function: hand back the argument object itself, untouched."""
    unchanged = x
    return unchanged


def upper(s):
    """Return the result of calling ``s.upper()`` (duck-typed: str or bytes)."""
    to_upper = s.upper
    return to_upper()


class Test_Backends:
    """Tests for the executor classes in ``MDAnalysis.analysis.backends``."""

    @pytest.mark.parametrize(
        "backend_cls,n_workers",
        [
            (backends.BackendBase, -1),
            (backends.BackendSerial, None),
            (backends.BackendMultiprocessing, "string"),
            (backends.BackendDask, ()),
        ],
    )
    def test_fails_incorrect_n_workers(self, backend_cls, n_workers):
        """Each backend rejects the given ``n_workers`` value with ValueError.

        NOTE: for ``BackendDask`` without dask installed, the ValueError may
        come from the missing-module check (see ``test_get_errors``) rather
        than from ``n_workers`` validation; either way ValueError is raised.
        """
        with pytest.raises(ValueError):
            _ = backend_cls(n_workers=n_workers)

    @pytest.mark.parametrize(
        "func,iterable,answer",
        [
            (square, (1, 2, 3), [1, 4, 9]),
            (square, (), []),
            (noop, list(range(10)), list(range(10))),
            (upper, "asdf", list("ASDF")),
        ],
    )
    def test_all_backends_give_correct_results(self, func, iterable, answer):
        """Every available backend's ``apply(func, iterable)`` equals ``answer``."""
        backend_instances = [
            backends.BackendMultiprocessing(n_workers=2),
            backends.BackendSerial(n_workers=1),
        ]
        # dask is optional; only exercise its backend when it is importable.
        if is_installed("dask"):
            backend_instances.append(backends.BackendDask(n_workers=2))

        for backend in backend_instances:
            result = backend.apply(func, iterable)
            # Name the backend so a failure pinpoints which one diverged.
            assert result == answer, (
                f"{type(backend).__name__}.apply returned {result!r}, "
                f"expected {answer!r}"
            )

    @pytest.mark.parametrize(
        "backend_cls,params,warning_message",
        [
            (
                backends.BackendSerial,
                {"n_workers": 5},
                "n_workers is ignored when executing with backend='serial'",
            ),
        ],
    )
    def test_get_warnings(self, backend_cls, params, warning_message):
        """Constructing ``backend_cls(**params)`` emits the expected UserWarning."""
        # ``match`` is a regex searched against the message; escape it so the
        # expected text is matched literally rather than as a pattern.
        with pytest.warns(UserWarning, match=re.escape(warning_message)):
            backend_cls(**params)

    @pytest.mark.parametrize(
        "backend_cls,params,error_message",
        [
            pytest.param(
                backends.BackendDask,
                {"n_workers": 2},
                (
                    "module 'dask' is missing. Please install 'dask': "
                    "https://docs.dask.org/en/stable/install.html"
                ),
                marks=pytest.mark.skipif(
                    is_installed("dask"), reason="dask is installed"
                ),
            )
        ],
    )
    def test_get_errors(self, backend_cls, params, error_message):
        """Constructing ``backend_cls(**params)`` raises the expected ValueError.

        The dask case only runs when dask is NOT installed (see skipif mark).
        """
        # Escape so '.' in "missing." and the URL must match a literal dot
        # instead of any character.
        with pytest.raises(ValueError, match=re.escape(error_message)):
            backend_cls(**params)