File: test_utils.py

package info (click to toggle)
python-mne 0.17%2Bdfsg-1
  • links: PTS, VCS
  • area: main
  • in suites: buster
  • size: 95,104 kB
  • sloc: python: 110,639; makefile: 222; sh: 15
file content (182 lines) | stat: -rw-r--r-- 6,317 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
# Authors: Alexandre Gramfort <alexandre.gramfort@telecom-paristech.fr>
#
# License: Simplified BSD

import os.path as op

import numpy as np
from numpy.testing import assert_allclose
import pytest

from mne.viz.utils import (compare_fiff, _fake_click, _compute_scalings,
                           _validate_if_list_of_axes, _get_color_list,
                           _setup_vmin_vmax, center_cmap)
from mne.viz import ClickableImage, add_background_image, mne_analyze_colormap
from mne.utils import run_tests_if_main
from mne.io import read_raw_fif
from mne.event import read_events
from mne.epochs import Epochs

# Set our plotters to test mode
import matplotlib
matplotlib.use('Agg')  # for testing don't use X server

# Shared fixture files, located relative to this file's directory at
# ../../io/tests/data.
base_dir = op.join(op.dirname(__file__), '..', '..', 'io', 'tests', 'data')
raw_fname, cov_fname, ev_fname = (
    op.join(base_dir, fixture)
    for fixture in ('test_raw.fif', 'test-cov.fif', 'test_raw-eve.fif'))


def test_setup_vmin_vmax_warns():
    """Check that _setup_vmin_vmax emits a UserWarning for negative data.

    With ``norm=True`` and data whose minimum is -1, the warning text is
    expected to mention the (min=0.0, max=1) range and the data minimum.
    """
    pattern = r'\(min=0.0, max=1\) range.*minimum of data is -1'
    negative_data = [-1, 0]
    with pytest.warns(UserWarning, match=pattern):
        _setup_vmin_vmax(data=negative_data, vmin=None, vmax=None, norm=True)


def test_get_color_list():
    """Check _get_color_list returns a list and omits red for annotations."""
    assert isinstance(_get_color_list(), list)
    annotation_colors = _get_color_list(annotations=True)
    red = '#ff0000'
    assert red not in annotation_colors


def test_mne_analyze_colormap():
    """Check that mne_analyze_colormap rejects invalid limit lists."""
    # Each of these limit specifications is expected to raise ValueError
    # (presumably: too few values, a negative value, non-increasing values).
    bad_limits = ([0], [-1, 1, 2], [0, 2, 1])
    for limits in bad_limits:
        with pytest.raises(ValueError):
            mne_analyze_colormap(limits)


def test_compare_fiff():
    """Smoke-test compare_fiff on the raw and covariance fixture files."""
    from matplotlib import pyplot
    compare_fiff(raw_fname, cov_fname, read_limit=0, show=False)
    pyplot.close('all')


def test_clickable_image():
    """Check ClickableImage click recording, layout export and plotting."""
    import matplotlib.pyplot as plt
    # Seeded random image to click on.
    image = np.random.RandomState(0).randn(100, 100)
    clickable = ClickableImage(image)
    expected = [(12, 8), (46, 48), (10, 24)]

    # Simulate one click per position, given in data coordinates.
    for xy in expected:
        _fake_click(clickable.fig, clickable.ax, xy, xform='data')
    assert_allclose(np.array(expected), np.array(clickable.coords))
    assert len(clickable.coords) == len(expected)

    # The exported layout holds one position per click and keeps the
    # ratio between x coordinates.
    layout = clickable.to_layout()
    assert layout.pos.shape[0] == len(expected)
    x_ratio = expected[1][0] / float(expected[2][0])
    assert_allclose(layout.pos[1, 0] / layout.pos[2, 0], x_ratio)

    clickable.plot_clicks()
    plt.close('all')


def test_add_background_image():
    """Test adding background image to a figure.

    Covers three cases:
    - adding a background without ``set_ratios`` leaves the existing axes'
      aspect untouched;
    - ``set_ratios='auto'`` also switches the existing axes to 'auto';
    - passing ``None`` as the image returns ``None``.
    """
    import matplotlib.pyplot as plt
    rng = np.random.RandomState(0)
    for ii in range(2):
        f, axs = plt.subplots(1, 2)
        x, y = rng.randn(2, 10)
        im = rng.randn(10, 10)
        axs[0].scatter(x, y)
        axs[1].scatter(y, x)
        for ax in axs:
            ax.set_aspect(1)

        if ii == 0:
            # Background without changing aspect: only the new background
            # axes is 'auto'; the scatter axes keep the aspect of 1 set above.
            # (A stray `return` here previously skipped every assertion
            # below, the whole second case and the None check.)
            ax_im = add_background_image(f, im)
            assert ax_im.get_aspect() == 'auto'
            for ax in axs:
                assert ax.get_aspect() == 1
        else:
            # Background with changing aspect: set_ratios='auto' is expected
            # to propagate to the existing axes as well.
            ax_im_asp = add_background_image(f, im, set_ratios='auto')
            assert ax_im_asp.get_aspect() == 'auto'
            for ax in axs:
                assert ax.get_aspect() == 'auto'
        plt.close('all')

    # Make sure passing None as image returns None
    f, axs = plt.subplots(1, 2)
    assert add_background_image(f, None) is None
    plt.close('all')


def test_auto_scale():
    """Check _compute_scalings 'auto' resolution for Raw and Epochs."""
    raw = read_raw_fif(raw_fname)
    epochs = Epochs(raw, read_events(ev_fname))
    rand_data = np.random.randn(10, 100)

    for inst in (raw, epochs):
        scale_grad = 1e10
        scalings_def = dict(eeg='auto', grad=scale_grad, stim='auto')

        # An unknown scalings string must be rejected by both entry points.
        pytest.raises(ValueError, inst.plot, scalings='foo')
        pytest.raises(ValueError, _compute_scalings, 'foo', inst)

        # Explicit values pass through unchanged; 'auto' entries get resolved.
        scalings_new = _compute_scalings(scalings_def, inst)
        assert scalings_new['grad'] == scale_grad
        assert scalings_new['eeg'] != 'auto'

    # A bare ndarray instead of Raw/Epochs is expected to raise.
    pytest.raises(ValueError, _compute_scalings, scalings_def, rand_data)

    # Only EEG channels remain after pick_types, so asking for 'auto'
    # gradiometer scaling is expected to raise.
    eeg_epochs = epochs[0].load_data()
    eeg_epochs.pick_types(eeg=True, meg=False)
    pytest.raises(ValueError, _compute_scalings,
                  dict(grad='auto'), eeg_epochs)


def test_validate_if_list_of_axes():
    """Test validation of axes.

    Exercises _validate_if_list_of_axes with valid flat arrays/lists of axes
    (with and without an expected count) and with a series of invalid inputs,
    each of which must raise ValueError.
    """
    import matplotlib.pyplot as plt
    fig, ax = plt.subplots(2, 2)
    # The 2D array returned by subplots(2, 2) is expected to be rejected.
    pytest.raises(ValueError, _validate_if_list_of_axes, ax)
    ax_flat = ax.ravel()
    ax = ax.ravel().tolist()

    # Valid: a flat array of 4 axes, with and without an explicit count.
    _validate_if_list_of_axes(ax_flat)
    _validate_if_list_of_axes(ax_flat, 4)

    # Invalid: expected count does not match (array and list forms).
    # (A verbatim duplicate of the list-form check was removed.)
    pytest.raises(ValueError, _validate_if_list_of_axes, ax_flat, 5)
    pytest.raises(ValueError, _validate_if_list_of_axes, ax, 3)

    # Invalid: non-axes inputs, and a single axes instead of a sequence.
    pytest.raises(ValueError, _validate_if_list_of_axes, 'error')
    pytest.raises(ValueError, _validate_if_list_of_axes, ['error'] * 2)
    pytest.raises(ValueError, _validate_if_list_of_axes, ax[0])

    # Invalid: a sequence with one element replaced by a non-axes value.
    ax_flat[2] = 23
    pytest.raises(ValueError, _validate_if_list_of_axes, ax_flat)

    # The untouched list form is still valid with the matching count.
    _validate_if_list_of_axes(ax, 4)
    plt.close('all')


def test_center_cmap():
    """Test centering of colormap.

    Centers the RdBu colormap on 0 for the data range [-5, 10] and checks
    that the range endpoints and 0 map to the original colormap's
    endpoints and midpoint.
    """
    import matplotlib.cm as cm
    # Normalize lives in matplotlib.colors; importing it via pyplot only
    # re-exports it and needlessly pulls in pyplot.
    from matplotlib.colors import LinearSegmentedColormap, Normalize
    cmap = center_cmap(cm.get_cmap("RdBu"), -5, 10)

    assert isinstance(cmap, LinearSegmentedColormap)

    # get new colors for values -5 (red), 0 (white), and 10 (blue)
    new_colors = cmap(Normalize(-5, 10)([-5, 0, 10]))
    # get original colors for 0 (red), 0.5 (white), and 1 (blue)
    reference = cm.RdBu([0., 0.5, 1.])
    assert_allclose(new_colors, reference)
    # new and old colors at 0.5 must be different
    assert not np.allclose(cmap(0.5), reference[1])


# Lets this test module be executed directly as a script. Presumably the
# mne.utils helper only runs the tests when __name__ == '__main__' (confirm).
run_tests_if_main()