File: test_ctps.py

package info (click to toggle)
python-mne 1.9.0-2
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 131,492 kB
  • sloc: python: 213,302; javascript: 12,910; sh: 447; makefile: 144
file content (88 lines) | stat: -rw-r--r-- 2,926 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
# Authors: The MNE-Python contributors.
# License: BSD-3-Clause
# Copyright the MNE-Python contributors.

import numpy as np
import pytest
from numpy.testing import assert_array_equal

from mne.preprocessing.ctps_ import _compute_normalized_phase, _prob_kuiper, ctps
from mne.time_frequency import morlet

###############################################################################
# Generate testing signal

tmin = -0.3  # start of the time window, in seconds
sfreq = 1000.0  # Hz
tstep = 1.0 / sfreq  # sampling interval implied by sfreq, in seconds
n_samples = 600
# endpoint=False keeps the spacing between consecutive samples exactly
# ``tstep``; including the endpoint would stretch the spacing to
# n_samples * tstep / (n_samples - 1), which disagrees with ``sfreq``.
times = np.linspace(tmin, tmin + n_samples * tstep, n_samples, endpoint=False)

# Generate time series from Morlet wavelet
single_trial = np.zeros((1, len(times)))
Ws = morlet(sfreq, [3], n_cycles=[1])

# Write the real part of the first wavelet at the start of the (zeroed) trial
single_trial[0][: len(Ws[0])] = np.real(Ws[0])
roll_to = 300 - 265  # shift data to center of time window
single_trial = np.roll(single_trial, roll_to)
# Seeded generator so the random and jittered test data are reproducible
rng = np.random.RandomState(42)


def get_data(n_trials, j_extent):
    """Generate ground truth and testing data.

    Parameters
    ----------
    n_trials : int
        Number of trials to generate.
    j_extent : int
        Jitter extent. Each jittered trial is the template trial circularly
        shifted by an integer drawn from ``[-j_extent, j_extent)``.

    Returns
    -------
    data : ndarray, shape (n_trials, 3, n_times)
        Along axis 1: the template repeated unchanged for every trial, the
        randomly shifted (jittered) template, and uniform random samples.
        ``n_times`` is the length of the module-level ``single_trial``.
    """
    # Take the trial length from the template rather than hard-coding it, so
    # this helper stays consistent if the module-level signal length changes.
    n_times = single_trial.shape[-1]
    ground_truth = np.tile(single_trial, n_trials)
    my_shape = n_trials, 1, n_times
    # Keep the draw order (random samples first, then shifts): both come from
    # the same seeded generator, so reordering would change the test data.
    random_data = rng.random_sample(my_shape)
    rand_ints = rng.randint(-j_extent, j_extent, n_trials)
    jittered_data = np.array([np.roll(single_trial, i) for i in rand_ints])
    data = np.concatenate(
        [
            ground_truth.reshape(my_shape),
            jittered_data.reshape(my_shape),
            random_data.reshape(my_shape),
        ],
        1,
    )

    assert data.shape == (n_trials, 3, n_times)
    return data


# vary extent of jittering --> creates phaselocks at the borders if
# 2 * extent != n_samples
# Each entry is (index, (n_trials, j_extent, pk_max)). Stored as a list, not a
# bare enumerate/zip iterator: a module-level iterator is exhausted after one
# pass, so a second run of the test in the same process would loop over
# nothing and silently pass.
iter_test_ctps = list(enumerate(zip([400, 400], [150, 300], [0.6, 0.2])))


def test_ctps():
    """Test basic ctps functionality."""
    for ii, (n_trials, j_extent, pk_max) in iter_test_ctps:
        data = get_data(n_trials, j_extent)
        ks_dyn, pk_dyn, phase_trial = ctps(data)
        # Feeding precomputed phases with is_raw=False must give the same
        # results as letting ctps() process the raw data itself.
        data2 = _compute_normalized_phase(data)
        ks_dyn2, pk_dyn2, _ = ctps(data2, is_raw=False)
        for a, b in zip([ks_dyn, pk_dyn, phase_trial], [ks_dyn2, pk_dyn2, data2]):
            assert_array_equal(a, b)
            assert a.min() >= 0
            assert a.max() <= 1
            assert b.min() >= 0
            assert b.max() <= 1

        # test for normalization
        assert (pk_dyn.min() > 0.0) or (pk_dyn.max() < 1.0)
        # test shapes
        assert phase_trial.shape == data.shape
        assert pk_dyn.shape == data.shape[1:]
        # test ground_truth, jittered and random channels (order of get_data)
        assert pk_dyn[0].max() == 1.0
        assert len(np.unique(pk_dyn[0])) == 1
        assert pk_dyn[1].max() < pk_max
        assert pk_dyn[2].max() > 0.3
        if ii < 1:
            # input with an extra trailing dimension must be rejected
            with pytest.raises(ValueError):
                ctps(data[:, :, :, None])

    assert _prob_kuiper(1.0, 400) == 1.0
    # test vectorization: the array call must match the scalar call applied
    # to each element (comparing a call with itself could never fail)
    scalar_pk = np.ravel(_prob_kuiper(1.0, 400))
    vector_pk = _prob_kuiper(np.array([1.0, 1.0]), 400)
    assert_array_equal(vector_pk, np.repeat(scalar_pk, 2))
    assert _prob_kuiper(0.1, 400) < 0.1