File: test_progressbar.py

package info (click to toggle)
python-mne 1.9.0-2
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 131,492 kB
  • sloc: python: 213,302; javascript: 12,910; sh: 447; makefile: 144
file content (130 lines) | stat: -rw-r--r-- 4,354 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
# Authors: The MNE-Python contributors.
# License: BSD-3-Clause
# Copyright the MNE-Python contributors.

from pathlib import Path

import numpy as np
import pytest
from numpy.testing import assert_array_equal

from mne.parallel import parallel_func
from mne.utils import ProgressBar, array_split_idx, catch_logging, use_log_level


def test_progressbar(monkeypatch):
    """Test ProgressBar construction, iteration errors, and MNE_TQDM backends."""
    # Array input: kept as the iterable, its length becomes max_value
    values = np.arange(10)
    bar = ProgressBar(values)
    assert bar.iterable is values
    assert bar.max_value == 10

    # Integer input: sets max_value only, there is no iterable
    bar = ProgressBar(10)
    assert bar.max_value == 10
    assert bar.iterable is None

    def consume(obj):
        for _ in obj:
            pass

    # Iterating a bar built without an iterable must raise TypeError
    with pytest.raises(TypeError, match="not iterable"):
        consume(bar)

    def run_with_debug_log():
        # Drain a small ProgressBar and return the captured debug log text
        with catch_logging("debug") as captured, ProgressBar(np.arange(3)) as live:
            for _ in live:
                pass
        return captured.getvalue()

    # A valid tqdm backend is announced in the debug log
    monkeypatch.setenv("MNE_TQDM", "tqdm")
    assert "Using ProgressBar with tqdm\n" in run_with_debug_log()

    # Malformed backend names are rejected when the ProgressBar is constructed
    for setting, message in (
        ("broken", "Invalid value for the"),
        ("tqdm.broken", "Unknown tqdm"),
    ):
        monkeypatch.setenv("MNE_TQDM", setting)
        with pytest.raises(ValueError, match=message):
            ProgressBar(np.arange(3))

    # With MNE_TQDM=off the debug log contains only the backend announcement
    monkeypatch.setenv("MNE_TQDM", "off")
    assert run_with_debug_log() == "Using ProgressBar with off\n"


def _identity(x):
    """Return ``x`` unchanged; used as a trivial worker for ``parallel_func``."""
    return x


def test_progressbar_parallel_basic(capsys):
    """Check that parallel_func's built-in progress bar reaches 100%."""
    # Nothing should have been written to stdout before the test starts
    assert capsys.readouterr().out == ""
    parallel, run, _ = parallel_func(_identity, total=10, n_jobs=1, verbose=True)
    with use_log_level(True):
        results = parallel(run(value) for value in range(10))
    assert results == list(range(10))
    # The completed progress bar is expected on stderr
    assert "100%" in capsys.readouterr().err


def _identity_block(x, pb):
    """Report one progress step per element of ``x`` and return ``x`` unchanged.

    ``pb.update`` is called with the running count 1, 2, ..., ``len(x)``.
    """
    for count in range(1, len(x) + 1):
        pb.update(count)
    return x


def test_progressbar_parallel_advanced(capsys):
    """Check a shared ProgressBar whose subsets are updated by parallel workers."""
    assert capsys.readouterr().out == ""
    # n_jobs must be 1 because "capsys" won't get stdout properly otherwise
    parallel, run, _ = parallel_func(_identity_block, n_jobs=1, verbose=False)
    arr = np.arange(10)
    n_items = len(arr)
    with use_log_level(True), ProgressBar(n_items) as pb:
        chunks = parallel(
            run(chunk, pb.subset(idx)) for idx, chunk in array_split_idx(arr, 2)
        )
        assert Path(pb._mmap_fname).is_file()
        # Every slot of the memory-mapped boolean buffer should be set
        done = np.memmap(pb._mmap_fname, dtype="bool", mode="r", shape=10).sum()
        assert done == n_items
    # Leaving the context must remove the backing file
    assert not Path(pb._mmap_fname).is_file(), "__exit__ not called?"
    assert_array_equal(np.concatenate(chunks), arr)
    assert "100%" in capsys.readouterr().err


def _identity_block_wide(x, pb):
    """Report two progress steps per element of ``x``.

    ``pb.update`` receives the running count 1, 2, ..., ``2 * len(x)``.

    Returns
    -------
    tuple
        ``(x, pb.idx)``: the input unchanged and the bar's ``idx`` attribute,
        read after all updates have been issued.
    """
    n_steps = 2 * len(x)
    for step in range(1, n_steps + 1):
        pb.update(step)
    return x, pb.idx


def test_progressbar_parallel_more(capsys):
    """Test ProgressBar with parallel computing, multiple steps per element.

    Each element of ``arr`` contributes ``n_per_split=2`` progress steps, so the
    shared bar has ``2 * len(arr)`` slots. Check that the workers' ``pb.idx``
    values together cover every slot, that every slot is marked in the
    memory-mapped buffer, that the buffer file is removed on exit, and that
    the split data comes back unchanged.
    """
    assert capsys.readouterr().out == ""
    # This must be "1" because "capsys" won't get stdout properly otherwise
    parallel, p_fun, _ = parallel_func(_identity_block_wide, n_jobs=1, verbose=False)
    arr = np.arange(10)
    n_steps = len(arr) * 2
    with use_log_level(True):
        with ProgressBar(n_steps) as pb:
            # Each result is an (input chunk, pb.idx) tuple from the worker
            results = parallel(
                p_fun(x, pb.subset(pb_idx))
                for pb_idx, x in array_split_idx(arr, 2, n_per_split=2)
            )
            idxs = np.concatenate([idx for _, idx in results])
            assert_array_equal(idxs, np.arange(n_steps))
            data = np.concatenate([chunk for chunk, _ in results])
            assert Path(pb._mmap_fname).is_file()
            sum_ = np.memmap(
                pb._mmap_fname, dtype="bool", mode="r", shape=n_steps
            ).sum()
            assert sum_ == n_steps
    assert not Path(pb._mmap_fname).is_file(), "__exit__ not called?"
    # The split data must round-trip unchanged, as in the "advanced" test
    assert_array_equal(data, arr)
    err = capsys.readouterr().err
    assert "100%" in err