File: test_renderer.py

package info (click to toggle)
python-mne 1.3.0+dfsg-1
  • links: PTS, VCS
  • area: main
  • in suites: bookworm
  • size: 100,172 kB
  • sloc: python: 166,349; pascal: 3,602; javascript: 1,472; sh: 334; makefile: 236
file content (221 lines) | stat: -rw-r--r-- 7,607 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
# Authors: Alexandre Gramfort <alexandre.gramfort@inria.fr>
#          Eric Larson <larson.eric.d@gmail.com>
#          Joan Massich <mailsik@gmail.com>
#          Guillaume Favelier <guillaume.favelier@gmail.com>
#
# License: Simplified BSD

import os
import sys

import pytest
import numpy as np

from mne.utils import run_subprocess
from mne.viz import set_3d_backend, get_3d_backend, Figure3D
from mne.viz.backends.renderer import _get_renderer
from mne.viz.backends.tests._utils import skips_if_not_pyvistaqt
from mne.viz.backends._utils import ALLOWED_QUIVER_MODES


@pytest.mark.parametrize(
    'backend',
    [
        pytest.param('pyvistaqt', marks=skips_if_not_pyvistaqt),
        pytest.param('foo', marks=pytest.mark.xfail(raises=ValueError)),
    ],
)
def test_backend_environment_setup(backend, monkeypatch):
    """Check that a backend named in ``MNE_3D_BACKEND`` can be selected.

    The ``'foo'`` case is marked as an expected failure raising
    ``ValueError``.
    """
    # Put the requested backend in the environment and clear the cached
    # module-level selection so nothing stale is carried over.
    monkeypatch.setenv("MNE_3D_BACKEND", backend)
    monkeypatch.setattr(
        'mne.viz.backends.renderer.MNE_3D_BACKEND', None)
    assert os.environ['MNE_3D_BACKEND'] == backend  # sanity check

    # Select the backend through the renderer module, then confirm that
    # both the module-level state and the public getter report it.
    from mne.viz.backends import renderer as renderer_module
    renderer_module.set_3d_backend(backend)
    assert renderer_module.MNE_3D_BACKEND == backend
    assert renderer_module.get_3d_backend() == backend


def test_3d_functions(renderer):
    """Exercise the figure-management helpers on one small figure."""
    figure = renderer.create_3d_figure((300, 300))
    assert isinstance(figure, Figure3D)

    # Wrap the figure in a backend renderer and draw a unit sphere at the
    # origin so the later view/title/screenshot calls have content.
    wrapper = renderer.backend._Renderer(fig=figure)
    origin = np.array([0., 0., 0.])
    wrapper.sphere(origin, 'w', 1.)

    renderer.backend._check_3d_figure(figure)
    renderer.set_3d_view(
        figure=figure,
        azimuth=None,
        elevation=None,
        focalpoint=(0., 0., 0.),
        distance=None,
    )
    renderer.set_3d_title(figure=figure, title='foo')
    renderer.backend._take_3d_screenshot(figure=figure)

    # Close this figure explicitly, then everything that may remain.
    renderer.close_3d_figure(figure)
    renderer.close_all_3d_figures()


def test_3d_backend(renderer):
    """Draw each supported primitive into a single non-scene renderer."""
    # Tetrahedron with one vertex at the origin and one on each axis.
    size = 1.0
    x = np.array([0, size, 0, 0])
    y = np.array([0, 0, size, 0])
    z = np.array([0, 0, 0, size])
    faces = np.array([[0, 1, 2],
                      [0, 1, 3],
                      [0, 2, 3],
                      [1, 2, 3]])
    vertices = np.column_stack((x, y, z))

    # Arrows are anchored at each face centroid and point away from the
    # mean of those centroids.
    face_centers = vertices[faces].mean(axis=1)
    center = np.mean(face_centers, axis=0)
    directions = face_centers - center

    surface = {"rr": vertices, "tris": faces}
    contour_scalars = np.array([0.0, 0.0, 0.0, 1.0])
    contour_levels = [0.2, 0.4, 0.6, 0.8]
    label = "renderer"

    ren = renderer.create_3d_figure(
        size=(600, 600),
        bgcolor='black',
        smooth_shading=True,
        scene=False,
    )
    for style in ('terrain', 'trackball'):
        ren.set_interaction(style)

    # Mesh: add it, then remove it again.
    tet_mesh = ren.mesh(x=x, y=y, z=z, triangles=faces, color='white')
    ren.remove_mesh(tet_mesh)

    # Contours, first as lines and then as tubes.
    for contour_kind in ('line', 'tube'):
        ren.contour(surface=surface, scalars=contour_scalars,
                    contours=contour_levels, kind=contour_kind)

    ren.sphere(center=vertices, color='red', scale=size / 3.0, radius=1.0)

    # Every allowed quiver mode must work; an unknown one must be rejected.
    arrow_args = dict(
        x=face_centers[:, 0],
        y=face_centers[:, 1],
        z=face_centers[:, 2],
        u=directions[:, 0],
        v=directions[:, 1],
        w=directions[:, 2],
        color='blue',
        scale=size / 2.0,
        scale_mode='scalar',
        scalars=np.linspace(1.0, 2.0, 4),
    )
    for quiver_mode in ALLOWED_QUIVER_MODES:
        ren.quiver3d(mode=quiver_mode, **arrow_args)
    with pytest.raises(ValueError, match='Invalid value'):
        ren.quiver3d(mode='foo', **arrow_args)

    # Tubes; the scalar-valued one feeds the scalar bar.
    ren.tube(origin=np.array([[0, 0, 0]]),
             destination=np.array([[0, 1, 0]]))
    _, tube_source = ren.tube(
        origin=np.array([[1, 0, 0]]),
        destination=np.array([[1, 1, 0]]),
        scalars=np.array([[1.0, 1.0]]),
    )
    ren.scalarbar(source=tube_source, title="Scalar Bar", bgcolor=[1, 1, 1])

    # 2D and 3D text, then camera placement and display.
    ren.text2d(x_window=0.0, y_window=0.0, text=label, size=14,
               justification='right')
    ren.text3d(x=0, y=0, z=0, text=label, scale=1.0)
    ren.set_camera(azimuth=180.0, elevation=90.0,
                   distance=5 * size, focalpoint=center)
    ren.reset_camera()
    ren.show()


def test_get_3d_backend(renderer):
    """Check that ``get_3d_backend`` does not alter the selected backend."""
    expected = renderer.MNE_3D_BACKEND
    # Query repeatedly: a second call would expose a first call that
    # changed the selection as a side effect.
    for _ in range(2):
        assert renderer.get_3d_backend() == expected


def test_renderer(renderer, monkeypatch):
    """Check that a fresh interpreter uses the backend given in the env."""
    current = renderer.get_3d_backend()
    # The child process creates a figure and asserts that it resolved the
    # same backend name that was exported via MNE_3D_BACKEND.
    script = (
        'import mne; mne.viz.create_3d_figure((800, 600), show=True); '
        'backend = mne.viz.get_3d_backend(); '
        f'assert backend == {current!r}, backend'
    )
    monkeypatch.setenv('MNE_3D_BACKEND', current)
    run_subprocess([sys.executable, '-uc', script])


def test_set_3d_backend_bad(monkeypatch, tmp_path):
    """Check the errors raised for an unknown or unloadable 3D backend."""
    with pytest.raises(
            ValueError, match="Allowed values are 'pyvistaqt' and 'notebook'"):
        set_3d_backend('invalid')

    # gh-9607: make every backend import fail and clear the cached choice.
    # Then no backend is reported, and requesting a renderer raises a
    # RuntimeError whose message mentions the pyvistaqt failure.
    def _raise_not_found(x):
        raise ModuleNotFoundError(x)

    monkeypatch.setattr(
        'mne.viz.backends.renderer._reload_backend', _raise_not_found)
    monkeypatch.setattr(
        'mne.viz.backends.renderer.MNE_3D_BACKEND', None)
    expected = 'Could not load any valid 3D.*\npyvistaqt: .*'
    assert get_3d_backend() is None
    with pytest.raises(RuntimeError, match=expected):
        _get_renderer()


def test_3d_warning(renderer_pyvistaqt, monkeypatch):
    """Check Mesa detection and the warning emitted for old Mesa versions."""
    fig = renderer_pyvistaqt.create_3d_figure((800, 600))
    is_mesa = renderer_pyvistaqt.backend._is_mesa
    plotter = fig.plotter

    def fake_capabilities(report):
        # Make the render window return ``report`` as its capability string.
        monkeypatch.setattr(
            plotter.ren_win, 'ReportCapabilities', lambda: report)

    # Recent Mesa: detected, no warning expected.
    fake_capabilities(
        'OpenGL renderer string: OpenGL 3.3 (Core Profile) '
        'Mesa 20.0.8 via llvmpipe (LLVM 10.0.0, 256 bits)\n')
    assert is_mesa(plotter)

    # Old Mesa: still detected, but a RuntimeWarning names the version.
    fake_capabilities(
        'OpenGL renderer string: OpenGL 3.3 (Core Profile) '
        'Mesa 18.3.4 via llvmpipe (LLVM 7.0, 256 bits)\n')
    with pytest.warns(RuntimeWarning, match=r'18\.3\.4 is too old'):
        assert is_mesa(plotter)

    # Non-Mesa driver string: not detected as Mesa.
    fake_capabilities('OpenGL 4.1 Metal - 76.3 via Apple M1 Pro\n')
    assert not is_mesa(plotter)