File: test_proj.py

package info (click to toggle)
python-mne 0.13.1%2Bdfsg-3
  • links: PTS, VCS
  • area: main
  • in suites: stretch
  • size: 92,032 kB
  • ctags: 8,249
  • sloc: python: 84,750; makefile: 205; sh: 15
file content (312 lines) | stat: -rw-r--r-- 12,837 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
import os.path as op
from nose.tools import assert_true, assert_raises
import warnings

import numpy as np
from numpy.testing import (assert_array_almost_equal, assert_allclose,
                           assert_equal)

import copy as cp

import mne
from mne.datasets import testing
from mne import pick_types
from mne.io import read_raw_fif
from mne import compute_proj_epochs, compute_proj_evoked, compute_proj_raw
from mne.io.proj import (make_projector, activate_proj,
                         _needs_eeg_average_ref_proj)
from mne.proj import (read_proj, write_proj, make_eeg_average_ref_proj,
                      _has_eeg_average_ref_proj)
from mne import read_events, Epochs, sensitivity_map, read_source_estimate
from mne.tests.common import assert_naming
from mne.utils import _TempDir, run_tests_if_main, slow_test

# Always emit warnings: several tests below count the warnings they record.
warnings.simplefilter('always')

# Small fixture files stored with the mne.io tests, relative to this module.
base_dir = op.join(op.dirname(__file__), '..', 'io', 'tests', 'data')
raw_fname, event_fname, proj_fname, proj_gz_fname, bads_fname = (
    op.join(base_dir, fixture)
    for fixture in ('test_raw.fif', 'test-eve.fif', 'test-proj.fif',
                    'test-proj.fif.gz', 'test_bads.txt'))

# Files below live in the (optional) testing dataset; download=False means
# the path is only resolved here, nothing is fetched at import time.
sample_path = op.join(testing.data_path(download=False), 'MEG', 'sample')
fwd_fname = op.join(sample_path, 'sample_audvis_trunc-meg-eeg-oct-4-fwd.fif')
# Template with two '%s' slots, filled in by test_sensitivity_maps.
sensmap_fname = op.join(
    sample_path, 'sample_audvis_trunc-%s-oct-4-fwd-sensmap-%s.w')

eog_fname, ecg_fname = (
    op.join(sample_path, proj_file)
    for proj_file in ('sample_audvis_eog-proj.fif',
                      'sample_audvis_ecg-proj.fif'))


def test_bad_proj():
    """Test that projecting a channel subset warns until projs are fixed."""
    raw = read_raw_fif(raw_fname, preload=True, add_eeg_ref=False)
    events = read_events(event_fname)
    meg_picks = pick_types(raw.info, meg=True, stim=False, ecg=False,
                           eog=False, exclude='bads')
    # keep only a handful of MEG channels
    subset = meg_picks[2:9:3]
    _check_warnings(raw, events, subset)

    # physically dropping the other channels is still bad ...
    kept_names = [raw.ch_names[idx] for idx in subset]
    raw.pick_channels(kept_names)
    remaining = np.arange(len(raw.ch_names))
    _check_warnings(raw, events, remaining)

    # ... whereas normalizing the projectors avoids the warnings ("fixed")
    raw.info.normalize_proj()
    remaining = np.arange(len(raw.ch_names))
    _check_warnings(raw, events, remaining, count=0)


def _check_warnings(raw, events, picks, count=3):
    """Helper to build projected Epochs and check the recorded warnings.

    Parameters
    ----------
    raw, events, picks
        Passed straight through to ``Epochs``.
    count : int
        Exact number of warnings the construction is expected to record.
        Every recorded warning must mention 'dangerous'.
    """
    event_id = dict(aud_l=1, vis_l=3)
    tmin, tmax = -0.2, 0.5
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter('always')
        Epochs(raw, events, event_id, tmin, tmax, picks=picks, preload=True,
               proj=True, add_eeg_ref=False)
    assert_equal(len(caught), count)
    messages = [str(record.message) for record in caught]
    for text in messages:
        assert_true('dangerous' in text)


@testing.requires_testing_data
def test_sensitivity_maps():
    """Test sensitivity map computation against stored reference maps."""
    fwd = mne.read_forward_solution(fwd_fname, surf_ori=True)
    # record (and thereby keep quiet) any warnings raised while reading
    with warnings.catch_warnings(record=True):
        warnings.simplefilter('always')
        projs = read_proj(eog_fname)
        projs.extend(read_proj(ecg_fname))
    decim = 6

    # Extra modes verified per channel type, as (projs, [(mode, suffix)]):
    # fixed (2), ratio (3), radiality (4), angle (5), remaining (6) and
    # dampening (7). Only the EEG modes are computed with the projectors.
    extra_checks = {
        'eeg': (projs, [('radiality', '4-lh'), ('angle', '5-lh'),
                        ('remaining', '6-lh'), ('dampening', '7-lh')]),
        'grad': (None, [('fixed', '2-lh')]),
        'mag': (None, [('ratio', '3-lh')]),
    }
    for ch_type in ('eeg', 'grad', 'mag'):
        expected = read_source_estimate(sensmap_fname % (ch_type, 'lh')).data
        stc = sensitivity_map(fwd, projs=None, ch_type=ch_type,
                              mode='free', exclude='bads')
        assert_array_almost_equal(stc.data, expected, decim)
        assert_true(stc.subject == 'sample')
        # let's just make sure the others run
        mode_projs, mode_ends = extra_checks[ch_type]
        for mode, end in mode_ends:
            expected = read_source_estimate(
                sensmap_fname % (ch_type, end)).data
            stc = sensitivity_map(fwd, projs=mode_projs, mode=mode,
                                  ch_type=ch_type, exclude='bads')
            assert_array_almost_equal(stc.data, expected, decim)

    # test corner case for EEG
    average_ref = make_eeg_average_ref_proj(fwd['info'])
    sensitivity_map(fwd, projs=[average_ref], ch_type='eeg', exclude='bads')
    # test corner case for projs being passed but no valid ones (#3135)
    assert_raises(ValueError, sensitivity_map, fwd, projs=None, mode='angle')
    assert_raises(RuntimeError, sensitivity_map, fwd, projs=[], mode='angle')
    # test volume source space
    vol_fname = op.join(sample_path, 'sample_audvis_trunc-meg-vol-7-fwd.fif')
    vol_fwd = mne.read_forward_solution(vol_fname)
    sensitivity_map(vol_fwd)


def test_compute_proj_epochs():
    """Test SSP computation on epochs.

    Projectors computed from epochs are compared with the stored reference
    files and with a freshly written copy, turned into a projection matrix,
    saved along with an evoked file, recomputed with ``n_jobs=2`` and finally
    written/read under a badly named file to check the naming warnings.
    """
    tempdir = _TempDir()
    event_id, tmin, tmax = 1, -0.2, 0.3

    raw = read_raw_fif(raw_fname, preload=True, add_eeg_ref=False)
    events = read_events(event_fname)
    bad_ch = 'MEG 2443'
    picks = pick_types(raw.info, meg=True, eeg=False, stim=False, eog=False,
                       exclude=[])
    epochs = Epochs(raw, events, event_id, tmin, tmax, picks=picks,
                    baseline=None, proj=False, add_eeg_ref=False)

    evoked = epochs.average()
    projs = compute_proj_epochs(epochs, n_grad=1, n_mag=1, n_eeg=0, n_jobs=1)
    write_proj(op.join(tempdir, 'test-proj.fif.gz'), projs)
    for p_fname in [proj_fname, proj_gz_fname,
                    op.join(tempdir, 'test-proj.fif.gz')]:
        projs2 = read_proj(p_fname)

        assert_true(len(projs) == len(projs2))

        for p1, p2 in zip(projs, projs2):
            assert_true(p1['desc'] == p2['desc'])
            assert_true(p1['data']['col_names'] == p2['data']['col_names'])
            assert_true(p1['active'] == p2['active'])
            # compare with sign invariance
            p1_data = p1['data']['data'] * np.sign(p1['data']['data'][0, 0])
            p2_data = p2['data']['data'] * np.sign(p2['data']['data'][0, 0])
            col_names = p1['data']['col_names']
            if bad_ch in col_names:
                # leave the bad channel's column out of the comparison
                bad = col_names.index(bad_ch)
                # Builtin ``bool``: the ``np.bool`` alias was removed in
                # NumPy 1.24. One mask entry per column of the projector.
                mask = np.ones(p1_data.shape[1], dtype=bool)
                mask[bad] = False
                p1_data = p1_data[:, mask]
                p2_data = p2_data[:, mask]
            corr = np.corrcoef(p1_data, p2_data)[0, 1]
            assert_array_almost_equal(corr, 1.0, 5)
            if p2['explained_var']:
                assert_array_almost_equal(p1['explained_var'],
                                          p2['explained_var'])

    # test that you can compute the projection matrix
    projs = activate_proj(projs)
    proj, nproj, U = make_projector(projs, epochs.ch_names, bads=[])

    assert_true(nproj == 2)
    assert_true(U.shape[1] == 2)

    # test that you can save them
    epochs.info['projs'] += projs
    evoked = epochs.average()
    evoked.save(op.join(tempdir, 'foo-ave.fif'))

    projs = read_proj(proj_fname)

    projs_evoked = compute_proj_evoked(evoked, n_grad=1, n_mag=1, n_eeg=0)
    assert_true(len(projs_evoked) == 2)
    # XXX : test something

    # test parallelization: must give the same projection matrix as n_jobs=1
    projs = compute_proj_epochs(epochs, n_grad=1, n_mag=1, n_eeg=0, n_jobs=2,
                                desc_prefix='foobar')
    assert_true(all('foobar' in x['desc'] for x in projs))
    projs = activate_proj(projs)
    proj_par, _, _ = make_projector(projs, epochs.ch_names, bads=[])
    assert_allclose(proj, proj_par, rtol=1e-8, atol=1e-16)

    # test warnings on bad filenames (one for the write, one for the read)
    with warnings.catch_warnings(record=True) as w:
        warnings.simplefilter('always')
        proj_badname = op.join(tempdir, 'test-bad-name.fif.gz')
        write_proj(proj_badname, projs)
        read_proj(proj_badname)
    assert_naming(w, 'test_proj.py', 2)


@slow_test
def test_compute_proj_raw():
    """Test SSP computation on raw data (segmented, continuous, resampled)."""
    out_dir = _TempDir()
    # Test that the raw projectors work
    n_sec = 2.5  # Do shorter amount for speed
    raw = read_raw_fif(raw_fname, add_eeg_ref=False).crop(0, n_sec)
    raw.load_data()
    for base in (0.25, 0.5, 1, 2):
        with warnings.catch_warnings(record=True) as caught:
            warnings.simplefilter('always')
            projs = compute_proj_raw(raw, duration=base - 0.1, stop=n_sec,
                                     n_grad=1, n_mag=1, n_eeg=0)
            assert_true(len(caught) == 1)

        # test that you can compute the projection matrix
        projs = activate_proj(projs)
        _, n_vec, basis = make_projector(projs, raw.ch_names, bads=[])
        assert_true(n_vec == 2)
        assert_true(basis.shape[1] == 2)

        # test that you can save them
        raw.info['projs'] += projs
        out_fname = op.join(out_dir, 'foo_%d_raw.fif' % base)
        raw.save(out_fname, overwrite=True)

    # Test that purely continuous (no duration) raw projection works
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter('always')
        projs = compute_proj_raw(raw, duration=None, stop=n_sec,
                                 n_grad=1, n_mag=1, n_eeg=0)
        assert_equal(len(caught), 1)

    # test that you can compute the projection matrix; this one is kept as
    # the reference for the resampled comparison below
    projs = activate_proj(projs)
    proj_ref, n_vec, basis = make_projector(projs, raw.ch_names, bads=[])
    assert_true(n_vec == 2)
    assert_true(basis.shape[1] == 2)

    # test that you can save them
    raw.info['projs'] += projs
    raw.save(op.join(out_dir, 'foo_rawproj_continuous_raw.fif'))

    # test resampled-data projector, upsampling instead of downsampling
    # here to save an extra filtering (raw would have to be LP'ed to be equiv)
    raw_up = cp.deepcopy(raw)
    doubled_sfreq = raw.info['sfreq'] * 2
    raw_up.resample(doubled_sfreq, n_jobs=2, npad='auto')
    with warnings.catch_warnings(record=True):
        warnings.simplefilter('always')
        projs = compute_proj_raw(raw_up, duration=None, stop=n_sec,
                                 n_grad=1, n_mag=1, n_eeg=0)
    projs = activate_proj(projs)
    proj_up, _, _ = make_projector(projs, raw.ch_names, bads=[])
    assert_array_almost_equal(proj_up, proj_ref, 4)

    # test with bads
    raw.load_bad_channels(bads_fname)  # adds 2 bad mag channels
    with warnings.catch_warnings(record=True):
        warnings.simplefilter('always')
        projs = compute_proj_raw(raw, n_grad=0, n_mag=0, n_eeg=1)

    # test that bad channels can be excluded: with every channel marked bad
    # the projection matrix is expected to be the identity
    proj_bads, _, _ = make_projector(projs, raw.ch_names, bads=raw.ch_names)
    assert_array_almost_equal(proj_bads, np.eye(len(raw.ch_names)))


def test_make_eeg_average_ref_proj():
    """Test EEG average reference projection.

    Checks that the across-channel EEG mean is non-zero before the
    projector is applied, (numerically) zero afterwards, and that creating
    the projector raises once ``custom_ref_applied`` is set.
    """
    raw = read_raw_fif(raw_fname, add_eeg_ref=False, preload=True)
    eeg = mne.pick_types(raw.info, meg=False, eeg=True)

    # No average EEG reference yet. Compare magnitudes: a signed comparison
    # would treat every negative mean as if it were zero.
    assert_true(not np.all(np.abs(raw._data[eeg].mean(axis=0)) < 1e-19))

    # Apply average EEG reference (on a copy, ``raw`` is reused below)
    car = make_eeg_average_ref_proj(raw.info)
    reref = raw.copy()
    reref.add_proj(car)
    reref.apply_proj()
    assert_array_almost_equal(reref._data[eeg].mean(axis=0), 0, decimal=19)

    # Error when custom reference has already been applied
    raw.info['custom_ref_applied'] = True
    assert_raises(RuntimeError, make_eeg_average_ref_proj, raw.info)


def test_has_eeg_average_ref_proj():
    """Test detecting an EEG average reference among projectors."""
    # an empty projector list has none
    no_projs = []
    assert_true(not _has_eeg_average_ref_proj(no_projs))

    # after set_eeg_reference() the info's projectors are expected to have one
    raw = read_raw_fif(raw_fname, add_eeg_ref=False, preload=False)
    raw.set_eeg_reference()
    current_projs = raw.info['projs']
    assert_true(_has_eeg_average_ref_proj(current_projs))


def test_needs_eeg_average_ref_proj():
    """Test the check for whether data need an EEG average reference."""
    # Freshly read recording, read without adding a reference: needed
    raw_plain = read_raw_fif(raw_fname, add_eeg_ref=False, preload=False)
    assert_true(_needs_eeg_average_ref_proj(raw_plain.info))

    # After set_eeg_reference(): no longer needed
    raw_plain.set_eeg_reference()
    assert_true(not _needs_eeg_average_ref_proj(raw_plain.info))

    # No EEG channels
    raw_no_eeg = read_raw_fif(raw_fname, add_eeg_ref=False, preload=True)
    eeg_idx = pick_types(raw_no_eeg.info, meg=False, eeg=True)
    eeg_names = [raw_no_eeg.ch_names[idx] for idx in eeg_idx]
    raw_no_eeg.drop_channels(eeg_names)
    assert_true(not _needs_eeg_average_ref_proj(raw_no_eeg.info))

    # Custom ref flag set
    raw_custom = read_raw_fif(raw_fname, add_eeg_ref=False, preload=False)
    raw_custom.info['custom_ref_applied'] = True
    assert_true(not _needs_eeg_average_ref_proj(raw_custom.info))

# Presumably runs the tests above when this file is executed as a script
# (helper imported from mne.utils) -- confirm against its implementation.
run_tests_if_main()