File: test_progressbar.py

package info (click to toggle)
python-mne 1.3.0+dfsg-1
  • links: PTS, VCS
  • area: main
  • in suites: bookworm
  • size: 100,172 kB
  • sloc: python: 166,349; pascal: 3,602; javascript: 1,472; sh: 334; makefile: 236
file content (132 lines) | stat: -rw-r--r-- 4,503 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
# -*- coding: utf-8 -*-
# Authors: Eric Larson <larson.eric.d@gmail.com>
#
# License: BSD-3-Clause

import os.path as op

import numpy as np
from numpy.testing import assert_array_equal
import pytest

from mne.parallel import parallel_func
from mne.utils import (ProgressBar, array_split_idx, use_log_level,
                       catch_logging)


def test_progressbar(monkeypatch):
    """Check ProgressBar construction, iteration and MNE_TQDM handling."""
    values = np.arange(10)
    bar = ProgressBar(values)
    # An array input is stored as-is and its length becomes max_value
    assert bar.iterable is values
    assert bar.max_value == 10

    # An integer input only sets max_value; there is nothing to iterate over
    bar = ProgressBar(10)
    assert bar.max_value == 10
    assert bar.iterable is None

    def _consume(seq):
        # Exhaust ``seq`` purely for the side effects of iterating it
        for _ in seq:
            pass

    # Iterating a ProgressBar built from an int must raise
    with pytest.raises(TypeError, match='not iterable'):
        _consume(bar)

    def _logged_run():
        # Iterate a small ProgressBar and return the debug log it produced
        with catch_logging('debug') as log, ProgressBar(np.arange(3)) as pb:
            _consume(pb)
        return log.getvalue()

    # Explicitly selecting the tqdm backend works and is logged
    monkeypatch.setenv('MNE_TQDM', 'tqdm')
    assert 'Using ProgressBar with tqdm\n' in _logged_run()

    # Unrecognized backend specifications are rejected at construction
    monkeypatch.setenv('MNE_TQDM', 'broken')
    with pytest.raises(ValueError, match='Invalid value for the'):
        ProgressBar(np.arange(3))
    monkeypatch.setenv('MNE_TQDM', 'tqdm.broken')
    with pytest.raises(ValueError, match='Unknown tqdm'):
        ProgressBar(np.arange(3))

    # 'off': the debug log must contain exactly the backend message
    monkeypatch.setenv('MNE_TQDM', 'off')
    assert _logged_run() == 'Using ProgressBar with off\n'


def _identity(x):
    return x


def test_progressbar_parallel_basic(capsys):
    """Check that ``parallel_func(total=...)`` drives a progress bar to 100%."""
    # Nothing may have been written to stdout before the test starts
    assert capsys.readouterr().out == ''
    n_items = 10
    parallel, run_one, _ = parallel_func(_identity, total=n_items, n_jobs=1,
                                         verbose=True)
    with use_log_level(True):
        results = parallel(run_one(item) for item in range(n_items))
    assert results == list(range(n_items))
    # The progress bar is rendered on stderr
    stderr_text = capsys.readouterr().err
    assert '100%' in stderr_text


def _identity_block(x, pb):
    for ii in range(len(x)):
        pb.update(ii + 1)
    return x


def test_progressbar_parallel_advanced(capsys):
    """Check a shared ProgressBar advanced by workers through ``subset``."""
    assert capsys.readouterr().out == ''
    # n_jobs must stay at 1, otherwise capsys does not capture the output
    parallel, run_block, _ = parallel_func(_identity_block, n_jobs=1,
                                           verbose=False)
    data = np.arange(10)
    n_items = len(data)
    with use_log_level(True), ProgressBar(n_items) as bar:
        chunks = parallel(run_block(chunk, bar.subset(idx))
                          for idx, chunk in array_split_idx(data, 2))
        # While the bar is open, its backing memmap file must exist and
        # have every step flagged as done
        assert op.isfile(bar._mmap_fname)
        done = np.memmap(bar._mmap_fname, dtype='bool', mode='r',
                         shape=n_items).sum()
        assert done == n_items
    # Leaving the context must remove the memmap file
    assert not op.isfile(bar._mmap_fname), '__exit__ not called?'
    assert_array_equal(np.concatenate(chunks), data)
    # The progress bar is rendered on stderr
    stderr_text = capsys.readouterr().err
    assert '100%' in stderr_text


def _identity_block_wide(x, pb):
    for ii in range(len(x)):
        for jj in range(2):
            pb.update(ii * 2 + jj + 1)
    return x, pb.idx


def test_progressbar_parallel_more(capsys):
    """Test ProgressBar with parallel computing, two steps per element.

    Each input element advances the bar twice (``n_per_split=2``), so the
    bar has ``2 * len(arr)`` steps in total.
    """
    assert capsys.readouterr().out == ''
    # n_jobs must stay at 1, otherwise capsys does not capture the output
    parallel, p_fun, _ = parallel_func(_identity_block_wide, n_jobs=1,
                                       verbose=False)
    arr = np.arange(10)
    n_steps = len(arr) * 2
    with use_log_level(True):
        with ProgressBar(n_steps) as pb:
            results = parallel(p_fun(x, pb.subset(pb_idx))
                               for pb_idx, x in array_split_idx(
                                   arr, 2, n_per_split=2))
            # Each worker returns (its data chunk, the bar indices it owned).
            # Together the indices must cover every step exactly once, in
            # order, and the chunks must reassemble the input.
            idxs = np.concatenate([res[1] for res in results])
            assert_array_equal(idxs, np.arange(n_steps))
            data = np.concatenate([res[0] for res in results])
            assert_array_equal(data, arr)
            # While the bar is open, its backing memmap file must exist and
            # have every step flagged as done
            assert op.isfile(pb._mmap_fname)
            sum_ = np.memmap(pb._mmap_fname, dtype='bool', mode='r',
                             shape=n_steps).sum()
            assert sum_ == n_steps
    # Leaving the context must remove the memmap file
    assert not op.isfile(pb._mmap_fname), '__exit__ not called?'
    # The progress bar is rendered on stderr
    stderr_text = capsys.readouterr().err
    assert '100%' in stderr_text