import numpy as np
from numpy.testing import (assert_array_almost_equal, assert_allclose,
assert_array_equal)
from scipy.signal import welch
import pytest
from mne.utils import catch_logging
from mne.time_frequency import psd_array_welch, psd_array_multitaper
from mne.time_frequency.multitaper import _psd_from_mt
from mne.time_frequency.psd import _median_biases


def test_psd_nan():
"""Test handling of NaN in psd_array_welch."""
n_samples, n_fft, n_overlap = 2048, 1024, 512
x = np.random.RandomState(0).randn(1, n_samples)
psds, freqs = psd_array_welch(x[:, :n_fft + n_overlap], float(n_fft),
n_fft=n_fft, n_overlap=n_overlap)
x[:, n_fft + n_overlap:] = np.nan # what Raw.get_data() will give us
psds_2, freqs_2 = psd_array_welch(x, float(n_fft), n_fft=n_fft,
n_overlap=n_overlap)
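    # The trailing NaN block should cause the affected segments to be dropped,
    # so this result is expected to match the NaN-free computation above.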
assert_allclose(freqs, freqs_2)
assert_allclose(psds, psds_2)
# 1-d
psds_2, freqs_2 = psd_array_welch(
x[0], float(n_fft), n_fft=n_fft, n_overlap=n_overlap)
assert_allclose(freqs, freqs_2)
assert_allclose(psds[0], psds_2)
# defaults
with catch_logging() as log:
psd_array_welch(x, float(n_fft), verbose='debug')
log = log.getvalue()
assert 'using 256-point FFT on 256 samples with 0 overlap' in log
assert 'hamming window' in log


def _make_psd_data():
"""Make noise data with sinusoids in 2 out of 7 channels."""
rng = np.random.default_rng(0)
n_chan, n_times, sfreq = 7, 8000, 1000
data = 0.1 * rng.random((n_chan, n_times))
times = np.arange(n_times) / sfreq
sinusoid_freqs = [8., 50.]
chs_with_sinusoids = [0, 1]
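    # amplitude-2 sinusoids easily dominate the 0.1-amplitude uniform noise,
    # so the PSD peaks of channels 0 and 1 should land at 8 Hz and 50 Hz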
for ix, freq in zip(chs_with_sinusoids, sinusoid_freqs):
data[ix, :] += 2 * np.sin(np.pi * 2. * freq * times)
return data, sfreq, sinusoid_freqs


@pytest.mark.parametrize(
'psd_func, psd_kwargs',
[(psd_array_welch, dict(n_fft=128, window='hann')),
(psd_array_multitaper, dict(low_bias=True))])
def test_psd(psd_func, psd_kwargs):
"""Tests the welch and multitaper PSD."""
data, sfreq, sinusoid_freqs = _make_psd_data()
# prepare kwargs
psd_kwargs.update(dict(fmin=2, fmax=70, verbose='debug'))
# compute PSD and test basic conformity
with catch_logging() as log:
psds, freqs = psd_func(data, sfreq, **psd_kwargs)
if psd_func is psd_array_welch:
log = log.getvalue()
n_fft = psd_kwargs['n_fft']
assert f'{n_fft}-point FFT on {n_fft} samples with 0 overl' in log
assert 'hann window' in log
assert psds.shape == (data.shape[0], len(freqs))
assert np.sum(freqs < 0) == 0
assert np.sum(psds < 0) == 0
# Is power found where it should be?
ixs_max = np.argmax(psds, axis=1)
for ixmax, ifreq in zip(ixs_max, sinusoid_freqs):
# Find nearest frequency to the "true" freq
ixtrue = np.argmin(np.abs(ifreq - freqs))
assert (np.abs(ixmax - ixtrue) < 2)


def test_psd_array_welch_nperseg_kwarg():
"""Test n_per_seg and padding in psd_array_welch()."""
data, sfreq, _ = _make_psd_data()
# prepare kwargs
kwargs = dict(fmin=2, fmax=70, n_per_seg=128)
# test n_per_seg in psd_array_welch (and padding)
psds1, freqs1 = psd_array_welch(data, sfreq, n_fft=128, **kwargs)
psds2, freqs2 = psd_array_welch(data, sfreq, n_fft=256, **kwargs)
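    # Each 128-sample segment is zero-padded to n_fft=256 in the second call,
    # which roughly doubles the number of frequency bins within [fmin, fmax].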
assert len(freqs1) == np.floor(len(freqs2) / 2.)
assert psds1.shape[-1] == np.floor(psds2.shape[-1] / 2.)
# test bad n_fft
with pytest.raises(ValueError, match='n_fft is not allowed to be > n_tim'):
kwargs.update(n_per_seg=None)
bad_n_fft = int(data.shape[-1] * 1.1)
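        # with n_per_seg=None, n_per_seg falls back to n_fft, so an n_fft
        # longer than the signal should be rejected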
psd_array_welch(data, sfreq, n_fft=bad_n_fft, **kwargs)
# test bad n_overlap
with pytest.raises(ValueError, match='n_overlap cannot be greater'):
kwargs.update(n_per_seg=64)
psd_array_welch(data, sfreq, n_fft=128, n_overlap=90, **kwargs)
# test bad fmin/fmax
with pytest.raises(ValueError, match='No frequencies found'):
psd_array_welch(data, sfreq, fmin=10, fmax=1)


def test_complex_multitaper():
"""Test complex-valued multitaper output."""
data, sfreq, _ = _make_psd_data()
psd_complex, freq_complex, weights = psd_array_multitaper(
data[:4, :500], sfreq, output='complex')
psd, freq = psd_array_multitaper(data[:4, :500], sfreq, output='power')
assert_array_equal(freq_complex, freq)
assert psd_complex.ndim == 3 # channels x tapers x freqs
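    # _psd_from_mt recombines the per-taper complex spectra using the taper
    # weights; this should reproduce the output='power' result.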
psd_from_complex = _psd_from_mt(psd_complex, weights)
assert_allclose(psd_from_complex, psd)


# Copied from SciPy
def _median_bias(n):
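    # Returns the bias of the median of `n` independent periodogram estimates
    # relative to their mean; dividing the median by this factor makes it
    # comparable to the mean-based (average='mean') PSD.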
ii_2 = 2 * np.arange(1., (n - 1) // 2 + 1)
return 1 + np.sum(1. / (ii_2 + 1) - 1. / ii_2)


@pytest.mark.parametrize('crop', (False, True))
def test_psd_array_welch_average_kwarg(crop):
"""Test `average` kwarg of psd_array_welch()."""
data, sfreq, _ = _make_psd_data()
# prepare kwargs
n_per_seg = 32
kwargs = dict(fmin=0, fmax=np.inf, n_fft=64, n_per_seg=n_per_seg,
n_overlap=0)
# optionally crop data by n_per_seg so that we are sure to test both an
# odd number and an even number of estimates (for median bias)
if crop:
data = data[..., :-n_per_seg]
# run with average=mean/median/None
psds_mean, freqs_mean = psd_array_welch(
data, sfreq, average='mean', **kwargs)
psds_median, freqs_median = psd_array_welch(
data, sfreq, average='median', **kwargs)
psds_unagg, freqs_unagg = psd_array_welch(
data, sfreq, average=None, **kwargs)
# Frequencies should be equal across all "average" types, as we feed in
# the exact same data.
assert_array_equal(freqs_mean, freqs_median)
assert_array_equal(freqs_mean, freqs_unagg)
# For `average=None`, the last dimension contains the un-aggregated
# segments.
assert psds_mean.shape == psds_median.shape
assert psds_mean.shape == psds_unagg.shape[:-1]
assert_array_equal(psds_mean, psds_unagg.mean(axis=-1))
# Compare with manual median calculation (_median_bias copied from SciPy)
bias = _median_bias(psds_unagg.shape[-1])
assert_allclose(psds_median, np.median(psds_unagg, axis=-1) / bias)
# check shape of unagg
n_chan, n_times = data.shape
n_freq = len(freqs_unagg)
n_segs = np.ceil(n_times / n_per_seg).astype(int)
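    # with n_times=8000 (or 7968 when cropped) and 32-sample segments there
    # are 250 (even) or 249 (odd) estimates, exercising both median-bias paths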
assert n_segs % 2 == (1 if crop else 0)
assert psds_unagg.shape == (n_chan, n_freq, n_segs)


@pytest.mark.parametrize('n', (2, 3, 5, 8, 12, 13, 14, 15))
def test_median_biases(n):
"""Test vectorization of median_biases."""
want_biases = np.concatenate(
([1., 1.], [_median_bias(ii) for ii in range(2, n + 1)]))
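    # _median_biases(n) should return the correction indexed by segment count
    # (0 through n); for fewer than 3 segments no correction is needed (bias 1)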
got_biases = _median_biases(n)
assert_allclose(want_biases, got_biases)
assert_allclose(got_biases[n], _median_bias(n))
assert_allclose(got_biases[:3], 1.)


@pytest.mark.slowtest
def test_compares_psd():
"""Test PSD estimation on raw for plt.psd and scipy.signal.welch."""
data, sfreq, _ = _make_psd_data()
n_fft = 2048
fmin, fmax = 2, 70
# Compute PSD with psd_array_welch
psds_mne, freqs_mne = psd_array_welch(
data, sfreq, fmin=fmin, fmax=fmax, n_fft=n_fft)
# Compute psds with scipy.signal.welch
freqs_scipy, psds_scipy = welch(
data, fs=sfreq, nperseg=n_fft, noverlap=0, window='hamming')
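    # both calls use a Hamming window and zero overlap, and (assuming MNE's
    # default one-sided density scaling) the estimates should agree closely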
# restrict to the relevant frequencies
mask = (freqs_scipy >= fmin) & (freqs_scipy <= fmax)
freqs_scipy = freqs_scipy[mask]
psds_scipy = psds_scipy[:, mask]
# make sure they match
assert_array_almost_equal(psds_mne, psds_scipy)
assert_array_equal(freqs_mne, freqs_scipy)
assert (psds_mne.shape == (data.shape[0], len(freqs_mne)))
assert (psds_scipy.shape == (data.shape[0], len(freqs_scipy)))
assert (np.sum(freqs_mne < 0) == 0)
assert (np.sum(freqs_scipy < 0) == 0)
assert (np.sum(psds_mne < 0) == 0)
assert (np.sum(psds_scipy < 0) == 0)


def test_psd_array_welch_n_jobs():
"""Test that n_jobs works even with more jobs than channels."""
data = np.zeros((1, 2048))
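    # only one channel of work is available, so n_jobs=2 exceeds it; the call
    # should still run cleanly rather than raising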
psd_array_welch(data, 1024, n_jobs=1)
psd_array_welch(data, 1024, n_jobs=2)