1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194
|
import numpy as np
from numpy.testing import assert_array_almost_equal
from nose.tools import assert_true, assert_raises
from mne.fixes import tril_indices
from mne.connectivity import spectral_connectivity
from mne.connectivity.spectral import _CohEst
from mne import SourceEstimate
from mne.filter import band_pass_filter
def _stc_gen(data, sfreq, tmin, combo=False):
    """Simulate a SourceEstimate generator.

    Parameters
    ----------
    data : iterable of arrays
        One item per epoch. The caller in this file passes an array of shape
        (n_epochs, n_signals, n_times), so each item is (n_signals, n_times).
    sfreq : float
        Sampling frequency; the time step of each estimate is ``1 / sfreq``.
    tmin : float
        Start time handed to every SourceEstimate.
    combo : bool
        If False, yield one SourceEstimate per epoch built from all rows.
        If True, yield ``(arr, stc)`` tuples where ``arr`` is the first row
        of the epoch and ``stc`` is built from the remaining rows.

    Yields
    ------
    SourceEstimate or tuple of (array, SourceEstimate)
    """
    tstep = 1 / float(sfreq)

    def _make_stc(rows):
        # One vertex per data row, all placed in the first vertex array; the
        # second array stays empty. Sizing it from ``rows`` keeps the vertex
        # count in step with the data even when combo mode drops a row.
        vertices = [np.arange(rows.shape[0]), np.empty(0)]
        return SourceEstimate(data=rows, vertices=vertices,
                              tmin=tmin, tstep=tstep)

    for d in data:
        if not combo:
            yield _make_stc(d)
        else:
            # simulate a combination of array and source estimate
            # NOTE(review): ``d[0]`` is a single row (1-D for a 2-D epoch);
            # confirm the consumer accepts that shape before relying on it.
            arr = d[0]
            yield (arr, _make_stc(d[1:]))
def test_spectral_connectivity():
    """Test frequency-domain connectivity methods.

    Builds random epochs in which signal 1 is a 5-15 Hz band-passed copy of
    signal 0 plus a little noise, then for every mode/method combination:

    1. checks that invalid arguments raise ``ValueError``;
    2. computes all-to-all connectivity and, for 'coh', 'cohy' and 'imcoh',
       checks the values inside and outside the simulated band;
    3. recomputes the lower-triangle connections from a SourceEstimate
       generator with two methods and ``n_jobs=2`` and compares the result
       with step 2;
    4. recomputes with two frequency bands, ``fskip=1`` and
       ``faverage=True`` and compares against a manual average of step 3.
    """
    # Use a case known to have no spurious correlations (it would be bad if
    # nosetests could randomly fail):
    np.random.seed(0)

    sfreq = 50.
    n_signals = 3
    n_epochs = 10
    n_times = 500

    tmin = 0.
    tmax = (n_times - 1) / sfreq
    data = np.random.randn(n_epochs, n_signals, n_times)
    times_data = np.linspace(tmin, tmax, n_times)

    # simulate connectivity from 5Hz..15Hz
    fstart, fend = 5.0, 15.0
    for i in range(n_epochs):
        data[i, 1, :] = band_pass_filter(data[i, 0, :], sfreq, fstart, fend)
        # add some noise, so the spectrum is not exactly zero
        data[i, 1, :] += 1e-2 * np.random.randn(n_times)

    # First we test some invalid parameters:
    assert_raises(ValueError, spectral_connectivity, data, method='notamethod')
    assert_raises(ValueError, spectral_connectivity, data,
                  mode='notamode')

    # test invalid fmin fmax settings
    assert_raises(ValueError, spectral_connectivity, data, fmin=10,
                  fmax=10 + 0.5 * (sfreq / float(n_times)))
    assert_raises(ValueError, spectral_connectivity, data, fmin=10, fmax=5)
    assert_raises(ValueError, spectral_connectivity, data, fmin=(0, 11),
                  fmax=(5, 10))
    assert_raises(ValueError, spectral_connectivity, data, fmin=(11,),
                  fmax=(12, 15))

    # each method is listed once: the data and seed are fixed, so repeating
    # an entry would only re-run identical checks
    methods = ['coh', 'imcoh', 'cohy', 'plv', 'ppc', 'pli', 'pli2_unbiased',
               'wpli', 'wpli2_debiased']

    modes = ['multitaper', 'fourier', 'cwt_morlet']

    # define some frequencies for cwt
    cwt_frequencies = np.arange(3, 24.5, 1)

    # the probed connections and the frequency bands do not depend on the
    # mode/method, so define them once
    indices = tril_indices(n_signals, -1)
    fmin = (5., 15.)
    fmax = (15., 30.)

    for mode in modes:
        for method in methods:
            if method == 'coh' and mode == 'multitaper':
                # only check adaptive estimation for coh to reduce test time
                check_adaptive = [False, True]
            else:
                check_adaptive = [False]

            if method == 'coh' and mode == 'cwt_morlet':
                # so we also test using an array for num cycles
                cwt_n_cycles = 7. * np.ones(len(cwt_frequencies))
            else:
                cwt_n_cycles = 7.

            for adaptive in check_adaptive:
                mt_bandwidth = 1. if adaptive else None

                con, freqs, times, n, _ = spectral_connectivity(
                    data, method=method, mode=mode, indices=None,
                    sfreq=sfreq, mt_adaptive=adaptive, mt_low_bias=True,
                    mt_bandwidth=mt_bandwidth,
                    cwt_frequencies=cwt_frequencies,
                    cwt_n_cycles=cwt_n_cycles)

                assert_true(n == n_epochs)
                assert_array_almost_equal(times_data, times)

                if mode == 'multitaper':
                    upper_t = 0.95
                    lower_t = 0.5
                else:
                    # other estimates have higher variance
                    upper_t = 0.8
                    lower_t = 0.75

                # test the simulated signal
                if method == 'coh':
                    idx = np.searchsorted(freqs, (fstart + 1, fend - 1))
                    # we see something for zero-lag
                    assert_true(np.all(con[1, 0, idx[0]:idx[1]] > upper_t))

                    if mode != 'cwt_morlet':
                        idx = np.searchsorted(freqs, (fstart - 1, fend + 1))
                        assert_true(np.all(con[1, 0, :idx[0]] < lower_t))
                        assert_true(np.all(con[1, 0, idx[1]:] < lower_t))
                elif method == 'cohy':
                    idx = np.searchsorted(freqs, (fstart + 1, fend - 1))
                    # imaginary coh will be zero
                    assert_true(np.all(np.imag(con[1, 0, idx[0]:idx[1]])
                                       < lower_t))
                    # we see something for zero-lag
                    assert_true(np.all(np.abs(con[1, 0, idx[0]:idx[1]])
                                       > upper_t))

                    if mode != 'cwt_morlet':
                        idx = np.searchsorted(freqs, (fstart - 1, fend + 1))
                        assert_true(np.all(np.abs(con[1, 0, :idx[0]])
                                           < lower_t))
                        assert_true(np.all(np.abs(con[1, 0, idx[1]:])
                                           < lower_t))
                elif method == 'imcoh':
                    idx = np.searchsorted(freqs, (fstart + 1, fend - 1))
                    # imaginary coh will be zero
                    assert_true(np.all(con[1, 0, idx[0]:idx[1]] < lower_t))
                    idx = np.searchsorted(freqs, (fstart - 1, fend + 1))
                    assert_true(np.all(con[1, 0, :idx[0]] < lower_t))
                    assert_true(np.all(con[1, 0, idx[1]:] < lower_t))

                # compute same connections using indices and 2 jobs,
                # also add a second method
                test_methods = (method, _CohEst)
                # NOTE: only the plain SourceEstimate generator is used here;
                # the (array, stc) combo mode of _stc_gen is not exercised.
                stc_data = _stc_gen(data, sfreq, tmin)
                con2, freqs2, times2, n2, _ = spectral_connectivity(
                    stc_data, method=test_methods, mode=mode,
                    indices=indices, sfreq=sfreq, mt_adaptive=adaptive,
                    mt_low_bias=True, mt_bandwidth=mt_bandwidth,
                    tmin=tmin, tmax=tmax,
                    cwt_frequencies=cwt_frequencies,
                    cwt_n_cycles=cwt_n_cycles, n_jobs=2)

                assert_true(isinstance(con2, list))
                assert_true(len(con2) == 2)

                if method == 'coh':
                    # 'coh' and the _CohEst class must give the same values
                    assert_array_almost_equal(con2[0], con2[1])

                con2 = con2[0]  # only keep the first method

                # we get the same result for the probed connections
                assert_array_almost_equal(freqs, freqs2)
                assert_array_almost_equal(con[indices], con2)
                assert_true(n == n2)
                assert_array_almost_equal(times_data, times2)

                # compute same connections for two bands, fskip=1, and f. avg.
                con3, freqs3, times3, n3, _ = spectral_connectivity(
                    data, method=method, mode=mode, indices=indices,
                    sfreq=sfreq, fmin=fmin, fmax=fmax, fskip=1,
                    faverage=True, mt_adaptive=adaptive, mt_low_bias=True,
                    mt_bandwidth=mt_bandwidth,
                    cwt_frequencies=cwt_frequencies,
                    cwt_n_cycles=cwt_n_cycles)

                assert_true(isinstance(freqs3, list))
                assert_true(len(freqs3) == len(fmin))
                for band_freqs, band_min, band_max in zip(freqs3, fmin, fmax):
                    assert_true(np.all((band_freqs >= band_min)
                                       & (band_freqs <= band_max)))

                # average con2 "manually" and we get the same result
                for i, band_freqs in enumerate(freqs3):
                    freq_idx = np.searchsorted(freqs2, band_freqs)
                    con2_avg = np.mean(con2[:, freq_idx], axis=1)
                    assert_array_almost_equal(con2_avg, con3[:, i])
|