1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86
|
import re

import pytest
from MDAnalysis.analysis import backends
from MDAnalysis.lib.util import is_installed
def square(x: int):
    """Return ``x`` raised to the second power.

    Module-level helper handed to the backends' ``apply`` in the tests below.
    """
    return pow(x, 2)
def noop(x):
    """Identity helper: hand back the argument unchanged."""
    unchanged = x
    return unchanged
def upper(s):
    """Return the upper-cased form of ``s`` via its own ``upper`` method."""
    shouted = s.upper()
    return shouted
class Test_Backends:
    """Tests for the execution backends in ``MDAnalysis.analysis.backends``."""

    @pytest.mark.parametrize(
        "backend_cls,n_workers",
        [
            (backends.BackendBase, -1),
            (backends.BackendSerial, None),
            (backends.BackendMultiprocessing, "string"),
            (backends.BackendDask, ()),
        ],
    )
    def test_fails_incorrect_n_workers(self, backend_cls, n_workers):
        """Constructing a backend with an invalid ``n_workers`` raises ValueError."""
        with pytest.raises(ValueError):
            _ = backend_cls(n_workers=n_workers)

    @pytest.mark.parametrize(
        "func,iterable,answer",
        [
            (square, (1, 2, 3), [1, 4, 9]),
            (square, (), []),
            (noop, list(range(10)), list(range(10))),
            (upper, "asdf", list("ASDF")),
        ],
    )
    def test_all_backends_give_correct_results(self, func, iterable, answer):
        """Each available backend's ``apply(func, iterable)`` equals ``answer``.

        The dask backend is only exercised when dask is installed.
        """
        backend_instances = [
            backends.BackendMultiprocessing(n_workers=2),
            backends.BackendSerial(n_workers=1),
        ]
        if is_installed("dask"):
            backend_instances.append(backends.BackendDask(n_workers=2))
        # Assert per backend (instead of collecting results in a dict keyed by
        # backend instance) so backends need not be hashable and a failure
        # names the backend that produced the wrong result.
        for backend in backend_instances:
            result = backend.apply(func, iterable)
            assert result == answer, (
                f"{type(backend).__name__} returned {result!r}, "
                f"expected {answer!r}"
            )

    @pytest.mark.parametrize(
        "backend_cls,params,warning_message",
        [
            (
                backends.BackendSerial,
                {"n_workers": 5},
                "n_workers is ignored when executing with backend='serial'",
            ),
        ],
    )
    def test_get_warnings(self, backend_cls, params, warning_message):
        """Constructing the backend with ``params`` emits the expected UserWarning."""
        # ``match`` is a regex; escape it so the message is matched literally.
        with pytest.warns(UserWarning, match=re.escape(warning_message)):
            backend_cls(**params)

    @pytest.mark.parametrize(
        "backend_cls,params,error_message",
        [
            pytest.param(
                backends.BackendDask,
                {"n_workers": 2},
                (
                    "module 'dask' is missing. Please install 'dask': "
                    "https://docs.dask.org/en/stable/install.html"
                ),
                marks=pytest.mark.skipif(
                    is_installed("dask"), reason="dask is installed"
                ),
            )
        ],
    )
    def test_get_errors(self, backend_cls, params, error_message):
        """Constructing the backend with ``params`` raises the expected ValueError."""
        # ``match`` is a regex; the message contains '.' (e.g. in the URL),
        # which would otherwise match any character, so match it literally.
        with pytest.raises(ValueError, match=re.escape(error_message)):
            backend_cls(**params)
|