1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79
|
import os.path as op
import pytest
import mne
from mne.datasets import testing
from mne import read_forward_solution
from mne.minimum_norm import (read_inverse_operator,
point_spread_function, cross_talk_function)
from mne.utils import run_tests_if_main
# Locations of the files in the MNE "testing" dataset used by this module.
# ``download=False`` means the dataset is not fetched here; tests that need
# it are guarded with ``testing.requires_testing_data``.
data_path = op.join(testing.data_path(download=False), 'MEG', 'sample')
fname_fwd = op.join(data_path, 'sample_audvis_trunc-meg-eeg-oct-4-fwd.fif')
fname_inv_meg = op.join(
    data_path, 'sample_audvis_trunc-meg-eeg-oct-4-meg-inv.fif')
fname_inv_meeg = op.join(
    data_path, 'sample_audvis_trunc-meg-eeg-oct-4-'
               'meg-eeg-diagnoise-inv.fif')
# Label files, in right-then-left order.
fname_label = [op.join(data_path, 'labels', label_file)
               for label_file in ('Aud-rh.label', 'Aud-lh.label')]

# ``lambda2`` is passed to the PSF/CTF functions and is derived from ``snr``
# as 1 / snr ** 2.
snr = 3.0
lambda2 = 1.0 / (snr ** 2)
def _expected_stc_shape(inverse_operator, n_labels, mode, n_svd_comp):
    """Return the ``(n_vert, n_samples)`` shape expected for a PSF/CTF stc.

    Parameters
    ----------
    inverse_operator : dict
        Inverse operator as returned by ``read_inverse_operator``. Only the
        vertices of its first two source spaces (``src[0]`` and ``src[1]``)
        are counted.
    n_labels : int
        Number of labels the PSF/CTF was computed for.
    mode : str
        ``'svd'`` yields ``n_svd_comp`` samples per label; any other mode
        (e.g. ``'sum'``) yields one sample per label.
    n_svd_comp : int
        Number of SVD components per label used when ``mode == 'svd'``.

    Returns
    -------
    shape : tuple of int
        ``(n_vert, n_samples)``.
    """
    src = inverse_operator['src']
    n_vert = src[0]['vertno'].shape[0] + src[1]['vertno'].shape[0]
    n_per_label = n_svd_comp if mode == 'svd' else 1
    # The "+ 1" is one extra sample beyond the per-label ones (presumably a
    # summary across all labels -- TODO confirm against the MNE docs).
    return n_vert, n_labels * n_per_label + 1


@pytest.mark.slowtest
@testing.requires_testing_data
def test_psf_ctf():
    """Test computation of PSFs and CTFs for linear estimators."""
    forward = read_forward_solution(fname_fwd)
    labels = [mne.read_label(ss) for ss in fname_label]
    method = 'MNE'
    n_svd_comp = 2
    modes = ('sum', 'svd')

    # make sure it works for both types of inverses
    for fname_inv in (fname_inv_meg, fname_inv_meeg):
        inverse_operator = read_inverse_operator(fname_inv)

        # Test PSFs (then CTFs)
        for mode in modes:
            stc_psf, psf_ev = point_spread_function(
                inverse_operator, forward, method=method, labels=labels,
                lambda2=lambda2, pick_ori='normal', mode=mode,
                n_svd_comp=n_svd_comp, use_cps=True)
            n_vert, n_samples = stc_psf.shape
            assert (n_vert, n_samples) == _expected_stc_shape(
                inverse_operator, len(labels), mode, n_svd_comp)
            # Unpacking also checks that the evoked data is 2-D; only the
            # channel count is compared.
            n_chan, _ = psf_ev.data.shape
            assert n_chan == forward['nchan']

        # Test CTFs
        for mode in modes:
            stc_ctf = cross_talk_function(
                inverse_operator, forward, labels, method=method,
                lambda2=lambda2, signed=False, mode=mode,
                n_svd_comp=n_svd_comp, use_cps=True)
            n_vert, n_samples = stc_ctf.shape
            assert (n_vert, n_samples) == _expected_stc_shape(
                inverse_operator, len(labels), mode, n_svd_comp)
# mne.utils helper; presumably runs this module's tests when the file is
# executed directly as a script (TODO confirm against mne.utils).
run_tests_if_main()
|