File: test_stockwell.py

package info (click to toggle)
python-mne 1.9.0-2
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 131,492 kB
  • sloc: python: 213,302; javascript: 12,910; sh: 447; makefile: 144
file content (152 lines) | stat: -rw-r--r-- 5,313 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
# Authors: The MNE-Python contributors.
# License: BSD-3-Clause
# Copyright the MNE-Python contributors.

from pathlib import Path

import numpy as np
import pytest
from numpy.testing import (
    assert_allclose,
    assert_array_almost_equal,
    assert_array_less,
    assert_equal,
)
from scipy import fftpack

from mne import Epochs, make_fixed_length_events, read_events
from mne.io import read_raw_fif
from mne.time_frequency import AverageTFR, tfr_array_stockwell
from mne.time_frequency._stockwell import (
    _check_input_st,
    _precompute_st_windows,
    _st,
    _st_power_itc,
    tfr_stockwell,
)
from mne.utils import _record_warnings

# Shared test recordings live under ``io/tests/data``, resolved relative to
# the directory two levels above the one that contains this file.
base_dir = Path(__file__).parents[2].joinpath("io", "tests", "data")
raw_fname = base_dir.joinpath("test_raw.fif")
raw_ctf_fname = base_dir.joinpath("test_ctf_raw.fif")


def test_stockwell_ctf():
    """Smoke-test the Stockwell transform on gradient-compensated CTF data."""
    ctf_raw = read_raw_fif(raw_ctf_fname)
    ctf_raw.apply_gradient_compensation(3)
    fixed_events = make_fixed_length_events(ctf_raw, duration=0.5)
    epochs = Epochs(
        ctf_raw,
        fixed_events,
        tmin=-0.2,
        tmax=0.3,
        decim=10,
        preload=True,
        verbose="error",
    )
    evoked = epochs.average()
    # No assertions: the call completing without raising is the whole test.
    tfr_stockwell(evoked, verbose="error")


def test_stockwell_check_input():
    """Check the Stockwell input checker's output shape, n_fft and padding."""
    expected_n_fft = 128
    # One length just below a power of 2 and one that is exactly a power of 2.
    for n_times in (expected_n_fft - 1, expected_n_fft):
        signal = np.zeros((2, 10, n_times))
        with _record_warnings():  # warnings related to n_fft occur sometimes
            padded, n_fft, n_pad = _check_input_st(signal, None)

        # Leading dimensions are kept; only the last axis is expected to grow.
        assert_equal(padded.shape, signal.shape[:-1] + (expected_n_fft,))
        assert_equal(n_fft, expected_n_fft)
        assert_equal(n_pad, expected_n_fft - n_times)


def test_stockwell_st_no_zero_pad():
    """Smoke-test ``_st_power_itc`` on all-zero data (no assertions)."""
    n_rows, n_times = 20, 128
    zeros = np.zeros((n_rows, n_times))
    first_bin, last_bin = 1, 10
    sampling_rate, win_width = 30, 2
    windows = _precompute_st_windows(
        n_times, first_bin, last_bin, sampling_rate, win_width
    )
    # NOTE(review): the second argument is the literal 10, not the start bin
    # (1) that the windows were built with -- confirm this is intended.
    _st_power_itc(zeros, 10, True, 0, 1, windows)


def test_stockwell_core():
    """Test the core Stockwell transform on a synthetic windowed cosine.

    Two properties are checked: the power peak lies at the pulse frequency
    (within 1 Hz) and inside the pulse's time window, and summing the full
    transform along its last axis followed by an inverse FFT reproduces the
    input signal.
    """
    # adapted from
    # http://vcs.ynic.york.ac.uk/docs/naf/intro/concepts/timefreq.html
    sfreq = 1000.0  # make things easy to understand
    dur = 0.5
    onset, offset = 0.175, 0.275
    n_samp = int(sfreq * dur)
    t = np.arange(n_samp) / sfreq  # make an array for time
    pulse_freq = 15.0
    pulse = np.cos(2.0 * np.pi * pulse_freq * t)
    pulse[0 : int(onset * sfreq)] = 0.0  # Zero before our desired pulse
    pulse[int(offset * sfreq) :] = 0.0  # and zero after our desired pulse

    width = 0.5
    freqs = fftpack.fftfreq(len(pulse), 1.0 / sfreq)
    fmin, fmax = 1.0, 100.0
    # Indices of the FFT bins closest to the requested band edges.
    start_f, stop_f = (np.abs(freqs - f).argmin() for f in (fmin, fmax))
    W = _precompute_st_windows(n_samp, start_f, stop_f, sfreq, width)

    st_pulse = _st(pulse, start_f, W)
    st_pulse = np.abs(st_pulse) ** 2
    assert_equal(st_pulse.shape[-1], len(pulse))
    # Row i of the transform is taken to correspond to FFT bin start_f + i
    # (the windows only span start_f..stop_f), so shift the peak row by
    # start_f before looking it up in the full ``freqs`` array. Without the
    # shift the lookup is only right when start_f happens to be 0.
    peak_bin = start_f + st_pulse.max(axis=1).argmax(axis=0)
    st_max_freq = freqs[peak_bin]  # max freq
    assert_allclose(st_max_freq, pulse_freq, atol=1.0)
    assert onset < t[st_pulse.max(axis=0).argmax(axis=0)] < offset

    # test inversion to FFT, by averaging local spectra, see eq. 5 in
    # Moukadem, A., Bouguila, Z., Ould Abdeslam, D. and Alain Dieterlen.
    # "Stockwell transform optimization applied on the detection of split in
    # heart sounds."

    width = 1.0
    start_f, stop_f = 0, len(pulse)  # use every frequency bin
    W = _precompute_st_windows(n_samp, start_f, stop_f, sfreq, width)
    y = _st(pulse, start_f, W)
    # invert stockwell
    y_inv = fftpack.ifft(np.sum(y, axis=1)).real
    assert_array_almost_equal(pulse, y_inv)


def test_stockwell_api():
    """Test the public Stockwell API on epochs, evoked data and raw arrays.

    Exercises ``tfr_stockwell`` with several ``fmin``/``fmax`` combinations
    and checks the input validation and output shapes of
    ``tfr_array_stockwell``.
    """
    raw = read_raw_fif(raw_fname)
    event_id, tmin, tmax = 1, -0.2, 0.5
    event_name = base_dir / "test-eve.fif"
    events = read_events(event_name)
    epochs = Epochs(
        raw,
        events,
        event_id,
        tmin,
        tmax,
        picks=[0, 1, 3],  # XXX pick 2 has epochs of zeros.
    )
    for fmin, fmax in [(None, 50), (5, 50), (5, None)]:
        power, itc = tfr_stockwell(epochs, fmin=fmin, fmax=fmax, return_itc=True)
        if fmax is not None:
            assert power.freqs.max() <= fmax
        power_evoked = tfr_stockwell(
            epochs.average(), fmin=fmin, fmax=fmax, return_itc=False
        )
        # for multitaper these don't necessarily match, but they seem to
        # for stockwell... if this fails, this maybe could be changed
        # just to check the shape
        assert_array_almost_equal(power_evoked.data, power.data)
    # The checks below apply to the results of the last loop iteration only.
    assert isinstance(power, AverageTFR)
    assert isinstance(itc, AverageTFR)
    assert_equal(power.data.shape, itc.data.shape)
    assert itc.data.min() >= 0.0
    assert itc.data.max() <= 1.0
    assert np.log(power.data.max()) * 20 <= 0.0

    # Array interface: non-array and non-3D inputs must be rejected.
    with pytest.raises(TypeError, match="ndarray"):
        tfr_array_stockwell("foo", 1000.0)
    data = np.random.RandomState(0).randn(1, 1024)
    with pytest.raises(ValueError, match="3D with shape"):
        tfr_array_stockwell(data, 1000.0)
    data = data[np.newaxis]  # now 3D: (1, 1, 1024)
    power, itc, freqs = tfr_array_stockwell(data, 1000.0, return_itc=True)
    # ITC is expected to be 1 everywhere here (presumably because the leading
    # axis holds a single epoch -- confirm against the function's docs).
    assert_allclose(itc, np.ones_like(itc))
    assert power.shape == (1, len(freqs), data.shape[-1])
    assert_array_less(0, power)