File: test_parallel.py

package info (click to toggle)
python-mne 1.9.0-2
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 131,492 kB
  • sloc: python: 213,302; javascript: 12,910; sh: 447; makefile: 144
file content (52 lines) | stat: -rw-r--r-- 1,313 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
# Authors: The MNE-Python contributors.
# License: BSD-3-Clause
# Copyright the MNE-Python contributors.

import multiprocessing
import os
from contextlib import nullcontext

import pytest

from mne.parallel import parallel_func


@pytest.mark.parametrize(
    "n_jobs",
    [
        None,
        1,
        -1,
        "loky 2",
        "threading 3",
        "multiprocessing 4",
    ],
)
def test_parallel_func(n_jobs):
    """Test Parallel wrapping.

    Integer or ``None`` values of ``n_jobs`` are handed directly to
    ``parallel_func``. String values of the form ``"<backend> <count>"``
    instead open a joblib backend context with ``<count>`` workers and call
    ``parallel_func`` with ``n_jobs=None``, so the reported job count has to
    be picked up from that context.
    """
    joblib = pytest.importorskip("joblib")
    serial_flag = os.getenv("MNE_FORCE_SERIAL", "").lower()
    if serial_flag in ("true", "1"):
        # The expected job counts below assume forced-serial mode is off.
        pytest.skip("MNE_FORCE_SERIAL cannot be set")

    def _double(x):
        return x * 2

    if not isinstance(n_jobs, str):
        ctx = nullcontext()
        call_jobs = n_jobs
        if n_jobs is None or n_jobs >= 0:
            want_jobs = 1
        else:
            # Negative values count back from the CPU total (-1 -> all CPUs).
            want_jobs = multiprocessing.cpu_count() + 1 + n_jobs
    else:
        backend, count = n_jobs.split()
        want_jobs = int(count)
        configure = getattr(joblib, "parallel_config", None)
        if configure is None:
            # joblib < 1.3 only provides parallel_backend
            configure = joblib.parallel_backend
        ctx = configure(backend, n_jobs=want_jobs)
        # Pass no explicit count so it must come from the joblib context.
        call_jobs = None

    with ctx:
        _, _, got_jobs = parallel_func(_double, call_jobs, verbose="debug")
    assert got_jobs == want_jobs