File: test_time_frequency.py

package info (click to toggle)
python-mne 1.9.0-2
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 131,492 kB
  • sloc: python: 213,302; javascript: 12,910; sh: 447; makefile: 144
file content (49 lines) | stat: -rw-r--r-- 1,349 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
# Authors: The MNE-Python contributors.
# License: BSD-3-Clause
# Copyright the MNE-Python contributors.


import numpy as np
import pytest
from numpy.testing import assert_array_equal

pytest.importorskip("sklearn")

from sklearn.base import clone

from mne.decoding.time_frequency import TimeFrequency


def test_timefrequency():
    """Test TimeFrequency construction, cloning, fitting and output shapes.

    Covers rejection of invalid ``output`` values at construction time,
    compatibility with sklearn's ``clone``, and the shape of ``transform``
    output for 3-D and 2-D input, with and without decimation.
    """
    # Init
    freqs = [20, 21, 22]
    # Derive from freqs so the expected shape cannot drift from the input.
    n_freqs = len(freqs)
    tf = TimeFrequency(freqs, sfreq=100)
    # Invalid ``output`` values must be rejected when the estimator is built.
    for output in ["avg_power", "foo", None]:
        with pytest.raises(ValueError):
            TimeFrequency(freqs, output=output)
    clone(tf)

    # Clone an estimator built from array-valued arguments
    freqs_array = np.asarray(freqs)
    tf = TimeFrequency(freqs_array, 100, "morlet", freqs_array / 5.0)
    clone(tf)

    # Fit; a seeded generator keeps the test input reproducible across runs.
    rng = np.random.default_rng(0)
    n_epochs, n_chans, n_times = 10, 2, 100
    X = rng.random((n_epochs, n_chans, n_times))
    tf.fit(X, None)

    # Transform
    tf = TimeFrequency(freqs, sfreq=100)
    tf.fit_transform(X, None)
    # 3-D X: (epochs, channels, times) -> (epochs, channels, freqs, times)
    Xt = tf.transform(X)
    assert_array_equal(Xt.shape, [n_epochs, n_chans, n_freqs, n_times])
    # 2-D X: no channel axis in the input, none in the output
    Xt = tf.transform(X[:, 0, :])
    assert_array_equal(Xt.shape, [n_epochs, n_freqs, n_times])
    # 3-D with decim: the time axis shrinks by the decimation factor.
    # transform is called here without a prior fit on this instance.
    tf = TimeFrequency(freqs, sfreq=100, decim=2)
    Xt = tf.transform(X)
    assert_array_equal(Xt.shape, [n_epochs, n_chans, n_freqs, n_times // 2])
    # 2-D with decim
    Xt = tf.transform(X[:, 0, :])
    assert_array_equal(Xt.shape, [n_epochs, n_freqs, n_times // 2])