File: test_effective.py

package info (click to toggle)
python-mne 0.17%2Bdfsg-1
  • links: PTS, VCS
  • area: main
  • in suites: buster
  • size: 95,104 kB
  • sloc: python: 110,639; makefile: 222; sh: 15
file content (39 lines) | stat: -rw-r--r-- 1,238 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
import numpy as np
from numpy.testing import assert_array_almost_equal

from mne.connectivity import phase_slope_index


def test_psi():
    """Verify PSI sign, antisymmetry and CWT output length on lagged noise."""
    sfreq = 50.
    n_signals = 3
    n_epochs = 10
    n_times = 500
    lag = 10  # shift, in samples, applied to the copies of signal 0
    rng = np.random.RandomState(42)
    data = rng.randn(n_epochs, n_signals, n_times)

    # Overwrite signals 1 and 2 with shifted copies of signal 0, handling
    # every epoch in a single sliced assignment.
    data[:, 1, lag:] = data[:, 0, :-lag]  # signal 1 trails signal 0
    data[:, 2, :-lag] = data[:, 0, lag:]  # signal 2 runs ahead of signal 0

    # All-to-all estimate; only the PSI array is inspected, but the full
    # 5-tuple is unpacked so a change in return arity still fails loudly.
    psi_full, _, _, _, _ = phase_slope_index(
        data, mode='fourier', sfreq=sfreq)
    assert psi_full[1, 0, 0] < 0
    assert psi_full[2, 0, 0] > 0

    # Same estimate restricted to the single (0, 1) pair.
    pair = (np.array([0]), np.array([1]))
    psi_pair, _, _, _, _ = phase_slope_index(
        data, mode='fourier', sfreq=sfreq, indices=pair)

    # Swapping the two signals of a pair must only flip the sign.
    assert_array_almost_equal(psi_pair[0, 0], -psi_full[1, 0, 0])

    # Time-resolved variant using Morlet wavelets on the same pair.
    cwt_freqs = np.arange(5., 20, 0.5)
    psi_wavelet, _, _, _, _ = phase_slope_index(
        data, mode='cwt_morlet', sfreq=sfreq, cwt_freqs=cwt_freqs,
        indices=pair)

    assert np.all(psi_wavelet > 0)
    assert psi_wavelet.shape[-1] == n_times