
|
import numpy as np
from numpy.testing import (assert_array_almost_equal, assert_allclose,
assert_array_equal)
from scipy.signal import welch
import pytest
from mne.utils import catch_logging
from mne.time_frequency import psd_array_welch, psd_array_multitaper
from mne.time_frequency.multitaper import _psd_from_mt
from mne.time_frequency.psd import _median_biases
def test_psd_nan():
    """Check that trailing NaN samples do not change psd_array_welch output.

    A reference PSD is computed on the finite prefix of a random signal; the
    remainder of the signal is then filled with NaN and the PSD of the full
    array (both 2-D and 1-D) must match the reference. Finally, the debug log
    of a call with default parameters is checked.
    """
    n_samples, n_fft, n_overlap = 2048, 1024, 512
    sfreq = float(n_fft)
    n_valid = n_fft + n_overlap
    welch_kw = dict(n_fft=n_fft, n_overlap=n_overlap)
    signal = np.random.RandomState(0).randn(1, n_samples)
    ref_psds, ref_freqs = psd_array_welch(
        signal[:, :n_valid], sfreq, **welch_kw)
    signal[:, n_valid:] = np.nan  # what Raw.get_data() will give us
    # 2-D input first, then the same single channel as a 1-d array
    for arr, expected_psds in ((signal, ref_psds), (signal[0], ref_psds[0])):
        got_psds, got_freqs = psd_array_welch(arr, sfreq, **welch_kw)
        assert_allclose(ref_freqs, got_freqs)
        assert_allclose(expected_psds, got_psds)
    # defaults
    with catch_logging() as log:
        psd_array_welch(signal, sfreq, verbose='debug')
    log_text = log.getvalue()
    assert 'using 256-point FFT on 256 samples with 0 overlap' in log_text
    assert 'hamming window' in log_text
def _make_psd_data():
"""Make noise data with sinusoids in 2 out of 7 channels."""
rng = np.random.default_rng(0)
n_chan, n_times, sfreq = 7, 8000, 1000
data = 0.1 * rng.random((n_chan, n_times))
times = np.arange(n_times) / sfreq
sinusoid_freqs = [8., 50.]
chs_with_sinusoids = [0, 1]
for ix, freq in zip(chs_with_sinusoids, sinusoid_freqs):
data[ix, :] += 2 * np.sin(np.pi * 2. * freq * times)
return data, sfreq, sinusoid_freqs
@pytest.mark.parametrize(
    'psd_func, psd_kwargs',
    [(psd_array_welch, dict(n_fft=128, window='hann')),
     (psd_array_multitaper, dict(low_bias=True))])
def test_psd(psd_func, psd_kwargs):
    """Test the welch and multitaper PSD.

    Checks output shape and sign, checks the Welch debug log, and checks
    that the spectral peak of each sinusoid channel lies within one bin of
    the injected frequency.
    """
    data, sfreq, sinusoid_freqs = _make_psd_data()
    # Build a new dict instead of calling psd_kwargs.update(): the dict comes
    # from the parametrize decorator, is created once at import time and is
    # shared by every run of this test, so mutating it leaks state.
    psd_kwargs = dict(psd_kwargs, fmin=2, fmax=70, verbose='debug')
    # compute PSD and test basic conformity
    with catch_logging() as log:
        psds, freqs = psd_func(data, sfreq, **psd_kwargs)
    if psd_func is psd_array_welch:
        log = log.getvalue()
        n_fft = psd_kwargs['n_fft']
        assert f'{n_fft}-point FFT on {n_fft} samples with 0 overl' in log
        assert 'hann window' in log
    assert psds.shape == (data.shape[0], len(freqs))
    assert np.sum(freqs < 0) == 0
    assert np.sum(psds < 0) == 0
    # Is power found where it should be? zip() stops after the sinusoid
    # channels, which _make_psd_data places first (channel k <-> freq k).
    ixs_max = np.argmax(psds, axis=1)
    for ixmax, ifreq in zip(ixs_max, sinusoid_freqs):
        # Find nearest frequency to the "true" freq
        ixtrue = np.argmin(np.abs(ifreq - freqs))
        assert np.abs(ixmax - ixtrue) < 2
def test_psd_array_welch_nperseg_kwarg():
    """Test n_per_seg and padding in psd_array_welch(), plus bad arguments.

    Only the call expected to raise is placed inside each ``pytest.raises``
    block, so a ValueError from test setup cannot masquerade as a pass.
    """
    data, sfreq, _ = _make_psd_data()
    kwargs = dict(fmin=2, fmax=70, n_per_seg=128)
    # test n_per_seg in psd_array_welch (and padding): padding 128-sample
    # segments to n_fft=256 halves the bin spacing, so roughly twice as many
    # bins fall inside [fmin, fmax].
    psds1, freqs1 = psd_array_welch(data, sfreq, n_fft=128, **kwargs)
    psds2, freqs2 = psd_array_welch(data, sfreq, n_fft=256, **kwargs)
    assert len(freqs1) == len(freqs2) // 2
    assert psds1.shape[-1] == psds2.shape[-1] // 2
    # test bad n_fft (longer than the signal)
    no_seg_kwargs = dict(kwargs, n_per_seg=None)
    bad_n_fft = int(data.shape[-1] * 1.1)
    with pytest.raises(ValueError, match='n_fft is not allowed to be > n_tim'):
        psd_array_welch(data, sfreq, n_fft=bad_n_fft, **no_seg_kwargs)
    # test bad n_overlap (larger than n_per_seg)
    short_seg_kwargs = dict(kwargs, n_per_seg=64)
    with pytest.raises(ValueError, match='n_overlap cannot be greater'):
        psd_array_welch(data, sfreq, n_fft=128, n_overlap=90,
                        **short_seg_kwargs)
    # test bad fmin/fmax
    with pytest.raises(ValueError, match='No frequencies found'):
        psd_array_welch(data, sfreq, fmin=10, fmax=1)
def test_complex_multitaper():
    """Check that complex multitaper output reduces to the power output.

    The per-taper complex spectra, combined with the returned taper weights
    via ``_psd_from_mt``, must equal ``output='power'`` on the same data.
    """
    data, sfreq, _ = _make_psd_data()
    segment = data[:4, :500]
    tapered, freqs_cplx, weights = psd_array_multitaper(
        segment, sfreq, output='complex')
    power, freqs_pow = psd_array_multitaper(segment, sfreq, output='power')
    assert_array_equal(freqs_cplx, freqs_pow)
    assert tapered.ndim == 3  # channels x tapers x freqs
    assert_allclose(_psd_from_mt(tapered, weights), power)
# Copied from SciPy
def _median_bias(n):
ii_2 = 2 * np.arange(1., (n - 1) // 2 + 1)
return 1 + np.sum(1. / (ii_2 + 1) - 1. / ii_2)
@pytest.mark.parametrize('crop', (False, True))
def test_psd_array_welch_average_kwarg(crop):
    """Check the ``average`` kwarg of psd_array_welch().

    ``average=None`` returns per-segment PSDs on an extra last axis; their
    mean must equal ``average='mean'`` and their bias-corrected median must
    equal ``average='median'``.
    """
    data, sfreq, _ = _make_psd_data()
    n_per_seg = 32
    kwargs = dict(fmin=0, fmax=np.inf, n_fft=64, n_per_seg=n_per_seg,
                  n_overlap=0)
    # Dropping one segment's worth of samples flips the segment count, so
    # both an odd and an even number of estimates are exercised (median bias).
    if crop:
        data = data[..., :-n_per_seg]
    by_average = {
        average: psd_array_welch(data, sfreq, average=average, **kwargs)
        for average in ('mean', 'median', None)
    }
    psds_mean, freqs_mean = by_average['mean']
    psds_median, freqs_median = by_average['median']
    psds_unagg, freqs_unagg = by_average[None]
    # Identical input data -> identical frequency axis for every aggregation.
    for other_freqs in (freqs_median, freqs_unagg):
        assert_array_equal(freqs_mean, other_freqs)
    # The un-aggregated segments live on the last axis.
    assert psds_mean.shape == psds_median.shape
    assert psds_mean.shape == psds_unagg.shape[:-1]
    assert_array_equal(psds_mean, psds_unagg.mean(axis=-1))
    # Manual median with SciPy's bias correction.
    bias = _median_bias(psds_unagg.shape[-1])
    assert_allclose(psds_median, np.median(psds_unagg, axis=-1) / bias)
    # Segment count: ceil(n_times / n_per_seg) via integer division.
    n_chan, n_times = data.shape
    n_segs = -(-n_times // n_per_seg)
    assert n_segs % 2 == int(crop)
    assert psds_unagg.shape == (n_chan, len(freqs_unagg), n_segs)
@pytest.mark.parametrize('n', (2, 3, 5, 8, 12, 13, 14, 15))
def test_median_biases(n):
    """Check vectorized _median_biases(n) against the scalar _median_bias.

    Entry ``k`` of the vectorized result must equal ``_median_bias(k)``;
    entries 0 and 1 are expected to be exactly 1.
    """
    expected = np.array(
        [1., 1.] + [_median_bias(k) for k in range(2, n + 1)])
    actual = _median_biases(n)
    assert_allclose(expected, actual)
    assert_allclose(actual[n], _median_bias(n))
    assert_allclose(actual[:3], 1.)
@pytest.mark.slowtest
def test_compares_psd():
    """Compare psd_array_welch against scipy.signal.welch on array data.

    SciPy is run with a Hamming window and no overlap, matching
    psd_array_welch's defaults, and its output is restricted to the same
    [fmin, fmax] band before comparing PSD values and frequencies.
    """
    data, sfreq, _ = _make_psd_data()
    n_fft = 2048
    fmin, fmax = 2, 70
    psds_mne, freqs_mne = psd_array_welch(
        data, sfreq, fmin=fmin, fmax=fmax, n_fft=n_fft)
    freqs_scipy, psds_scipy = welch(
        data, fs=sfreq, nperseg=n_fft, noverlap=0, window='hamming')
    in_band = (freqs_scipy >= fmin) & (freqs_scipy <= fmax)
    freqs_scipy, psds_scipy = freqs_scipy[in_band], psds_scipy[:, in_band]
    # make sure they match
    assert_array_almost_equal(psds_mne, psds_scipy)
    assert_array_equal(freqs_mne, freqs_scipy)
    n_chan = data.shape[0]
    assert psds_mne.shape == (n_chan, len(freqs_mne))
    assert psds_scipy.shape == (n_chan, len(freqs_scipy))
    for arr in (freqs_mne, freqs_scipy, psds_mne, psds_scipy):
        assert np.sum(arr < 0) == 0
def test_psd_array_welch_n_jobs():
    """Check psd_array_welch runs with n_jobs both equal to and above n_chan.

    The input has a single channel, so ``n_jobs=2`` requests more jobs than
    there are channels to process.
    """
    single_channel = np.zeros((1, 2048))
    for n_jobs in (1, 2):
        psd_array_welch(single_channel, 1024, n_jobs=n_jobs)
|