File: test_stim.py

package info (click to toggle)
python-mne 1.9.0-2
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 131,492 kB
  • sloc: python: 213,302; javascript: 12,910; sh: 447; makefile: 144
file content (154 lines) | stat: -rw-r--r-- 5,341 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
# Authors: The MNE-Python contributors.
# License: BSD-3-Clause
# Copyright the MNE-Python contributors.

from pathlib import Path

import numpy as np
import pytest
from numpy.testing import assert_array_almost_equal

from mne.epochs import Epochs
from mne.event import read_events
from mne.io import read_raw_fif
from mne.preprocessing.stim import fix_stim_artifact

# Test fixtures live in the ``io/tests/data`` directory, resolved relative to
# ``Path(__file__).parents[2]``.
data_path = Path(__file__).parents[2].joinpath("io", "tests", "data")
raw_fname = data_path.joinpath("test_raw.fif")
event_fname = data_path.joinpath("test-eve.fif")


def _centered_diff(segment):
    """Return the first differences of 1D ``segment`` with their mean removed.

    For a straight line the first differences are constant, so the result is
    (numerically) all zeros; any curvature leaves non-zero residuals.
    """
    diffs = np.diff(segment)
    return diffs - np.mean(diffs)


def _assert_linear(segment):
    """Assert that the 1D ``segment`` is (almost) a straight line."""
    residual = _centered_diff(segment)
    assert_array_almost_equal(residual, np.zeros(len(residual)))


def test_fix_stim_artifact():
    """Test fix_stim_artifact on Epochs, Raw and Evoked instances.

    For each instance type the ``"linear"``, ``"window"`` and ``"constant"``
    modes are applied in turn, and a sample slice lying inside the corrected
    ``[tmin, tmax)`` interval is inspected after each call.
    """
    events = read_events(event_fname)

    # Without preload=True, fix_stim_artifact is expected to raise.
    raw = read_raw_fif(raw_fname)
    with pytest.raises(RuntimeError):
        fix_stim_artifact(raw)

    raw = read_raw_fif(raw_fname, preload=True)

    # ---- Epochs: correct a window before the stimulus ----
    tmin, tmax, event_id = -0.2, 0.5, 1
    picks = ("meg", "eeg", "eog")
    epochs = Epochs(
        raw, events, event_id, tmin, tmax, picks=picks, preload=True, reject=None
    )
    # Sample index of the epoch start relative to the event; subtracting it
    # turns event-relative sample offsets into indices into the epoch array.
    e_start = int(np.ceil(epochs.info["sfreq"] * epochs.tmin))
    tmin, tmax = -0.045, -0.015
    # The inspected slice [-0.035, -0.015) is a sub-range of [tmin, tmax);
    # presumably it starts later than tmin to stay clear of the interval
    # start (e.g. the tapered edge in "window" mode).
    tmin_samp = int(-0.035 * epochs.info["sfreq"]) - e_start
    tmax_samp = int(-0.015 * epochs.info["sfreq"]) - e_start

    epochs = fix_stim_artifact(
        epochs, tmin=tmin, tmax=tmax, mode="linear", picks=("eeg", "eog")
    )
    # Picked channels must be linearly interpolated across the interval ...
    data = epochs.get_data(("eeg", "eog"))[:, :, tmin_samp:tmax_samp]
    _assert_linear(data[0][0])

    # ... whereas MEG channels were not picked and must not be linearised.
    data = epochs.get_data("meg")[:, :, tmin_samp:tmax_samp]
    assert np.all(_centered_diff(data[0][0]) != 0)

    epochs = fix_stim_artifact(epochs, tmin=tmin, tmax=tmax, mode="window")
    data_from_epochs_fix = epochs.get_data(copy=False)[:, :, tmin_samp:tmax_samp]
    # NOTE(review): only requires at least one zero sample in the slice.
    assert np.any(data_from_epochs_fix == 0)

    baseline = (-0.1, -0.05)
    epochs = fix_stim_artifact(
        epochs, tmin=tmin, tmax=tmax, baseline=baseline, mode="constant"
    )
    # "constant" mode: the interval is filled with the baseline mean.
    b_start = int(np.ceil(epochs.info["sfreq"] * baseline[0]))
    b_end = int(np.ceil(epochs.info["sfreq"] * baseline[1]))
    base_t1 = b_start - e_start
    base_t2 = b_end - e_start
    baseline_mean = epochs.get_data()[:, :, base_t1:base_t2].mean(axis=2)[0][0]
    data = epochs.get_data()[:, :, tmin_samp:tmax_samp]
    assert data[0][0][0] == baseline_mean

    # ---- Raw: correct a window before the stimulus ----
    event_idx = np.where(events[:, 2] == 1)[0][0]
    tmin, tmax = -0.045, -0.015
    tmin_samp = int(-0.035 * raw.info["sfreq"])
    tmax_samp = int(-0.015 * raw.info["sfreq"])
    # Position of the first event with id 1 inside the raw data array.
    tidx = int(events[event_idx, 0] - raw.first_samp)

    with pytest.raises(ValueError):
        fix_stim_artifact(raw, events=np.array([]))
    # events=None: events are taken from the given stim channel instead.
    raw = fix_stim_artifact(
        raw,
        events=None,
        event_id=1,
        tmin=tmin,
        tmax=tmax,
        mode="linear",
        stim_channel="STI 014",
    )
    data, _ = raw[:, (tidx + tmin_samp) : (tidx + tmax_samp)]
    _assert_linear(data[0])

    raw = fix_stim_artifact(
        raw, events, event_id=1, tmin=tmin, tmax=tmax, mode="window"
    )
    data, _ = raw[:, (tidx + tmin_samp) : (tidx + tmax_samp)]
    # NOTE(review): this only requires *some* sample in the slice to be zero;
    # a stricter check (e.g. the window interior of the picked channels being
    # all zero) may have been intended — confirm before tightening.
    assert np.any(data == 0)

    raw = fix_stim_artifact(
        raw,
        events,
        event_id=1,
        tmin=tmin,
        tmax=tmax,
        baseline=baseline,
        mode="constant",
    )
    data, _ = raw[:, (tidx + tmin_samp) : (tidx + tmax_samp)]
    # b_start/b_end are event-relative sample offsets computed above.
    baseline_mean, _ = raw[:, (tidx + b_start) : (tidx + b_end)]
    assert baseline_mean.mean(axis=1)[0] == data[0][0]

    # ---- Epochs built from the already-fixed raw data ----
    tmin, tmax, event_id = -0.2, 0.5, 1
    epochs = Epochs(
        raw,
        events,
        event_id,
        tmin,
        tmax,
        picks=picks,
        preload=True,
        reject=None,
        baseline=None,
    )
    e_start = int(np.ceil(epochs.info["sfreq"] * epochs.tmin))
    tmin_samp = int(-0.035 * epochs.info["sfreq"]) - e_start
    tmax_samp = int(-0.015 * epochs.info["sfreq"]) - e_start
    data_from_raw_fix = epochs.get_data(copy=False)[:, :, tmin_samp:tmax_samp]
    # NOTE(review): only requires at least one zero sample in the slice.
    assert np.any(data_from_raw_fix == 0)

    # ---- Evoked: correct a window after the stimulus ----
    evoked = epochs.average()
    tmin, tmax = 0.005, 0.045
    tmin_samp = int(0.015 * evoked.info["sfreq"]) - evoked.first
    tmax_samp = int(0.035 * evoked.info["sfreq"]) - evoked.first

    evoked = fix_stim_artifact(evoked, tmin=tmin, tmax=tmax, mode="linear")
    data = evoked.data[:, tmin_samp:tmax_samp]
    _assert_linear(data[0])

    evoked = fix_stim_artifact(evoked, tmin=tmin, tmax=tmax, mode="window")
    data = evoked.data[:, tmin_samp:tmax_samp]
    # NOTE(review): only requires at least one zero sample in the slice.
    assert np.any(data == 0)

    evoked = fix_stim_artifact(
        evoked, tmin=tmin, tmax=tmax, baseline=baseline, mode="constant"
    )
    base_t1 = int(baseline[0] * evoked.info["sfreq"]) - evoked.first
    base_t2 = int(baseline[1] * evoked.info["sfreq"]) - evoked.first
    data = evoked.data[:, tmin_samp:tmax_samp]
    baseline_mean = evoked.data[:, base_t1:base_t2].mean(axis=1)[0]
    assert data[0][0] == baseline_mean