File: test_annotate_amplitude.py

package info (click to toggle)
python-mne 1.3.0%2Bdfsg-1
  • links: PTS, VCS
  • area: main
  • in suites: bookworm
  • size: 100,172 kB
  • sloc: python: 166,349; pascal: 3,602; javascript: 1,472; sh: 334; makefile: 236
file content (394 lines) | stat: -rw-r--r-- 16,764 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
# Author: Mathieu Scheltienne <mathieu.scheltienne@fcbg.ch>
#
# License: BSD-3-Clause

import datetime
import itertools
from pathlib import Path
import re

import numpy as np
import pytest

from mne import create_info
from mne.annotations import Annotations
from mne.datasets import testing
from mne.io import RawArray, read_raw_fif
from mne.preprocessing import annotate_amplitude


# Timezone-aware (UTC) measurement date used to parametrize the tests.
date = datetime.datetime(
    year=2021, month=12, day=10, hour=7, minute=52, second=24,
    microsecond=405305, tzinfo=datetime.timezone.utc)
# Location of the MNE testing dataset, queried with download=False.
data_path = Path(testing.data_path(download=False))
# Recording from the testing dataset, read by test_flat_bad_acq_skip.
skip_fname = data_path.joinpath('misc', 'intervalrecording_raw.fif')


@pytest.mark.parametrize('meas_date', (None, date))
@pytest.mark.parametrize('first_samp', (0, 10000))
def test_annotate_amplitude(meas_date, first_samp):
    """Test peak/flat annotation of synthetic EEG data.

    Random data on 11 EEG channels (the last one is listed in
    ``info['bads']``) is corrupted with flat or linearly drifting stretches.
    Depending on ``bad_percent``, the corrupted channels are expected either
    in the returned list of bad channels or as ``BAD_flat`` / ``BAD_peak``
    annotations, whose edges are verified with ``_check_annotation``.
    """
    n_ch, n_times = 11, 1000
    rng = np.random.RandomState(0)
    data = rng.randn(n_ch, n_times)
    # nothing flat at first: consecutive samples always differ
    assert np.all(np.diff(data, axis=-1) != 0)
    info = create_info(n_ch, 1000., 'eeg')
    # from annotate_flat: also run with first_samp != 0 (gh-6295)
    raw = RawArray(data, info, first_samp=first_samp)
    raw.info['bads'] = [raw.ch_names[-1]]
    raw.set_meas_date(meas_date)
    # linear drift (+10 per sample) spanning the whole recording
    full_drift = np.arange(0, raw.times.size * 10, 10)

    def check_peak_and_flat(inst, annotations, peak_span, flat_span):
        """Check labels and (start, stop) samples of the annotations."""
        labels = ('BAD_flat', 'BAD_peak')
        assert all(annot['description'] in labels for annot in annotations)
        for annot in annotations:
            if annot['description'] == 'BAD_peak':
                start_idx, stop_idx = peak_span
            else:
                start_idx, stop_idx = flat_span
            _check_annotation(inst, annot, meas_date, first_samp, start_idx,
                              stop_idx)

    # -- test bad channels spatial marking --
    # each case: (corrupted channels, new channel data, peak, flat)
    whole_channel_cases = (
        # entire channel flat
        ((0,), 0., None, 0.),
        # multiple channels flat
        ((0, 2), 0., None, 0.),
        # entire channel drifting
        ((0,), full_drift, 5, None),
        # multiple channels drifting
        ((0, 2), full_drift, 5, None),
    )
    percents = (5, 99.9, 100.)
    durations = (0.005, 0.95, 0.99)
    for perc, dur in itertools.product(percents, durations):
        for channels, values, peak, flat in whole_channel_cases:
            raw_ = raw.copy()
            for ch in channels:
                raw_._data[ch] = values
            annots, bads = annotate_amplitude(
                raw_, peak=peak, flat=flat, bad_percent=perc,
                min_duration=dur)
            assert len(annots) == 0
            assert bads == [str(ch) for ch in channels]

    # -- test bad channels temporal marking --
    # channel 0 is flat on the last 20% of the samples
    n_good_times = int(round(0.8 * n_times))
    raw_ = raw.copy()
    raw_._data[0, n_good_times:] = 0.
    for perc in (5, 20):
        annots, bads = annotate_amplitude(
            raw_, peak=None, flat=0., bad_percent=perc)
        assert len(annots) == 0
        assert bads == ['0']
    # just above 20%, an annotation is expected instead of a bad channel
    annots, bads = annotate_amplitude(
        raw_, peak=None, flat=0., bad_percent=20.1)
    assert len(annots) == 1
    assert len(bads) == 0
    flat_annot = annots[0]
    assert flat_annot['description'] == 'BAD_flat'
    _check_annotation(raw_, flat_annot, meas_date, first_samp, n_good_times,
                      -1)

    # test multiple channels flat and multiple channels drift
    raw_ = raw.copy()
    samples = raw_._data  # same array, edited in place
    samples[0, 800:] = 0.
    samples[1, 850:950] = 0.
    samples[2, :200] = np.arange(0, 200 * 10, 10)
    samples[2, 200:] += samples[2, 199]  # keep continuity after the drift
    samples[3, 50:150] = np.arange(0, 100 * 10, 10)
    samples[3, 150:] += samples[3, 149]  # keep continuity after the drift
    for perc in (5, 10):
        annots, bads = annotate_amplitude(
            raw_, peak=5, flat=0., bad_percent=perc)
        assert len(annots) == 0
        assert bads == ['0', '1', '2', '3']
    for perc in (10.1, 20):
        annots, bads = annotate_amplitude(
            raw_, peak=5, flat=0., bad_percent=perc)
        assert len(annots) == 2
        assert bads == ['0', '2']
        # the annotations come from the 10% segments (channels 3 and 1)
        check_peak_and_flat(raw_, annots, (50, 149), (850, 949))
    annots, bads = annotate_amplitude(
        raw_, peak=5, flat=0., bad_percent=20.1)
    assert len(annots) == 2
    assert len(bads) == 0
    # the annotations now span the 20% segments (channels 2 and 0)
    check_peak_and_flat(raw_, annots, (0, 199), (800, -1))

    # -- flat, then drift, on the channel already in info['bads'] --
    # each case: (new channel data, peak, flat)
    already_bad_cases = (
        (0., None, 0.),
        (full_drift, 5, None),
    )
    for values, peak, flat in already_bad_cases:
        raw_ = raw.copy()
        raw_._data[-1, :] = values
        annots, bads = annotate_amplitude(
            raw_, peak=peak, flat=flat, bad_percent=5)
        assert len(annots) == 0
        assert len(bads) == 0


@pytest.mark.parametrize('meas_date', (None, date))
@pytest.mark.parametrize('first_samp', (0, 10000))
def test_annotate_amplitude_with_overlap(meas_date, first_samp):
    """Test annotations when the corrupted segments overlap in time.

    Three situations are covered: a peak segment overlapping a flat segment
    (two annotations expected), two drifts on the same channel and two flat
    segments on different channels (a single annotation expected in both
    cases).
    """
    n_ch, n_times = 11, 1000
    rng = np.random.RandomState(0)
    data = rng.randn(n_ch, n_times)
    # nothing flat at first: consecutive samples always differ
    assert np.all(np.diff(data, axis=-1) != 0)
    info = create_info(n_ch, 1000., 'eeg')
    # from annotate_flat: also run with first_samp != 0 (gh-6295)
    raw = RawArray(data, info, first_samp=first_samp)
    raw.info['bads'] = [raw.ch_names[-1]]
    raw.set_meas_date(meas_date)

    # -- overlap between peak and flat --
    raw_ = raw.copy()
    samples = raw_._data  # same array, edited in place
    samples[0, 800:] = 0.
    samples[1, 700:900] = np.arange(0, 200 * 10, 10)
    samples[1, 900:] += samples[1, 899]  # keep continuity after the drift
    annots, bads = annotate_amplitude(raw_, peak=5, flat=0, bad_percent=25)
    assert len(annots) == 2
    assert len(bads) == 0
    # check annotation instances
    labels = ('BAD_flat', 'BAD_peak')
    assert all(annot['description'] in labels for annot in annots)
    for annot in annots:
        if annot['description'] == 'BAD_peak':
            start_idx, stop_idx = 700, 899
        else:
            start_idx, stop_idx = 800, -1
        _check_annotation(raw_, annot, meas_date, first_samp, start_idx,
                          stop_idx)

    # -- overlap between peak and peak on same channel --
    raw_ = raw.copy()
    samples = raw_._data
    samples[0, 700:900] = np.arange(0, 200 * 10, 10)
    # the second drift overwrites samples 800-899 of the first one
    samples[0, 800:] = np.arange(1000, 300 * 10, 10)
    annots, bads = annotate_amplitude(
        raw_, peak=5, flat=None, bad_percent=50)
    assert len(annots) == 1
    assert len(bads) == 0
    # check annotation instance
    peak_annot = annots[0]
    assert peak_annot['description'] == 'BAD_peak'
    _check_annotation(raw_, peak_annot, meas_date, first_samp, 700, -1)

    # -- overlap between flat and flat on different channel --
    raw_ = raw.copy()
    samples = raw_._data
    samples[0, 700:900] = 0.
    samples[1, 800:] = 0.
    annots, bads = annotate_amplitude(
        raw_, peak=None, flat=0.01, bad_percent=50)
    assert len(annots) == 1
    assert len(bads) == 0
    # check annotation instance
    flat_annot = annots[0]
    assert flat_annot['description'] == 'BAD_flat'
    _check_annotation(raw_, flat_annot, meas_date, first_samp, 700, -1)


@pytest.mark.parametrize('meas_date', (None, date))
@pytest.mark.parametrize('first_samp', (0, 10000))
def test_annotate_amplitude_multiple_ch_types(meas_date, first_samp):
    """Test annotations on a recording mixing EEG, mag and grad channels.

    An EEG channel (index 1) gets a flat ending and a gradiometer (index 5)
    gets a drifting start. The expected annotations depend on ``picks``:
    both with the default, only ``BAD_flat`` with ``picks='eeg'`` and only
    ``BAD_peak`` with ``picks='grad'``.
    """
    n_ch, n_times = 11, 1000
    rng = np.random.RandomState(0)
    data = rng.randn(n_ch, n_times)
    # nothing flat at first: consecutive samples always differ
    assert np.all(np.diff(data, axis=-1) != 0)
    ch_types = ['eeg'] * 3 + ['mag'] * 2 + ['grad'] * 4 + ['eeg'] * 2
    info = create_info(n_ch, 1000., ch_types)
    # from annotate_flat: also run with first_samp != 0 (gh-6295)
    raw = RawArray(data, info, first_samp=first_samp)
    raw.info['bads'] = [raw.ch_names[-1]]
    raw.set_meas_date(meas_date)

    def corrupted_copy():
        """Return a copy of raw with a flat EEG and a drifting grad."""
        inst = raw.copy()
        inst._data[1, 800:] = 0.
        inst._data[5, :200] = np.arange(0, 200 * 10, 10)
        # keep continuity after the drift
        inst._data[5, 200:] += inst._data[5, 199]
        return inst

    # -- 2 channel types both to annotate --
    raw_ = corrupted_copy()
    annots, bads = annotate_amplitude(raw_, peak=5, flat=0, bad_percent=50)
    assert len(annots) == 2
    assert len(bads) == 0
    # check annotation instances
    labels = ('BAD_flat', 'BAD_peak')
    assert all(annot['description'] in labels for annot in annots)
    for annot in annots:
        if annot['description'] == 'BAD_peak':
            start_idx, stop_idx = 0, 199
        else:
            start_idx, stop_idx = 800, -1
        _check_annotation(raw_, annot, meas_date, first_samp, start_idx,
                          stop_idx)

    # -- 2 channel types, one flat picked, one not picked --
    raw_ = corrupted_copy()
    annots, bads = annotate_amplitude(
        raw_, peak=5, flat=0, bad_percent=50, picks='eeg')
    assert len(annots) == 1
    assert len(bads) == 0
    # check annotation instance
    eeg_annot = annots[0]
    _check_annotation(raw_, eeg_annot, meas_date, first_samp, 800, -1)
    assert eeg_annot['description'] == 'BAD_flat'

    # -- 2 channel types, one flat, one not picked, reverse --
    raw_ = corrupted_copy()
    annots, bads = annotate_amplitude(
        raw_, peak=5, flat=0, bad_percent=50, picks='grad')
    assert len(annots) == 1
    assert len(bads) == 0
    # check annotation instance
    grad_annot = annots[0]
    _check_annotation(raw_, grad_annot, meas_date, first_samp, 0, 199)
    assert grad_annot['description'] == 'BAD_peak'


@testing.requires_testing_data
def test_flat_bad_acq_skip():
    """Test flat detection on data containing acquisition skips.

    The first part reads a recording from the testing dataset and compares
    the flat channels to a fixed list. The second part adds a
    ``bad_acq_skip`` annotation (0.5 s to 0.7 s) to synthetic data and checks
    that the expected ``BAD_flat`` annotations end or start at its edges.
    """
    # -- file with a couple of skip and flat channels --
    raw = read_raw_fif(skip_fname, preload=True)
    annots, bads = annotate_amplitude(raw, flat=0)
    assert len(annots) == 0
    # MaxFilter finds the same 21 channels
    flat_numbers = (
        141, 331, 421, 431, 611, 641, 1011, 1021, 1031, 1241, 1421,
        1741, 1841, 2011, 2131, 2141, 2241, 2531, 2541, 2611, 2621,
    )
    assert bads == ['MEG{:04d}'.format(num) for num in flat_numbers]

    # -- overlap of flat segment with bad_acq_skip --
    n_ch, n_times = 11, 1000
    rng = np.random.RandomState(0)
    data = rng.randn(n_ch, n_times)
    # nothing flat at first: consecutive samples always differ
    assert np.all(np.diff(data, axis=-1) != 0)
    info = create_info(n_ch, 1000., 'eeg')
    raw = RawArray(data, info, first_samp=0)
    raw.info['bads'] = [raw.ch_names[-1]]
    raw.set_annotations(Annotations(
        onset=[0.5], duration=[0.2], description=['bad_acq_skip'],
        orig_time=None))

    # each case: (flat samples on channel 0, expected start, expected stop)
    edge_cases = (
        # flat segment overlapping with the left edge of bad_acq_skip
        (slice(400, 600), 400, 499),
        # flat segment overlapping with the right edge of bad_acq_skip
        (slice(600, 800), 700, 799),
    )
    for flat_samples, start_idx, stop_idx in edge_cases:
        raw_ = raw.copy()
        raw_._data[0, flat_samples] = 0.
        annots, bads = annotate_amplitude(
            raw_, peak=None, flat=0, bad_percent=25)
        assert len(annots) == 1
        assert len(bads) == 0
        # check annotation instance
        assert annots[0]['description'] == 'BAD_flat'
        _check_annotation(raw_, annots[0], None, 0, start_idx, stop_idx)

    # flat segment overlapping entirely with bad_acq_skip
    raw_ = raw.copy()
    raw_._data[0, 200:800] = 0.
    annots, bads = annotate_amplitude(
        raw_, peak=None, flat=0, bad_percent=41)
    assert len(annots) == 2
    assert len(bads) == 0
    # check annotation instances, ordered by onset
    annots = sorted(annots, key=lambda x: x['onset'])
    assert all(annot['description'] == 'BAD_flat' for annot in annots)
    before_skip, after_skip = annots
    _check_annotation(raw_, before_skip, None, 0, 200, 500)
    _check_annotation(raw_, after_skip, None, 0, 700, 799)


def _check_annotation(raw, annot, meas_date, first_samp, start_idx, stop_idx):
    """Assert that an annotation spans the expected samples of raw.

    Parameters
    ----------
    raw : object with ``times`` and ``info['sfreq']``
        Recording the annotation was computed on.
    annot : dict-like
        Annotation providing 'onset', 'duration' and 'orig_time'.
    meas_date : datetime | None
        Expected value of ``annot['orig_time']``.
    first_samp : int
        First sample of the recording; only used when meas_date is set.
    start_idx, stop_idx : int
        Indices in ``raw.times`` expected at the start and at the end of the
        annotation.
    """
    assert meas_date == annot['orig_time']
    # with a measurement date, the onset is compared after removing the time
    # of the first sample; without one it is compared to raw.times directly
    offset = 0. if meas_date is None else first_samp / raw.info['sfreq']
    expected_start = raw.times[start_idx]
    expected_stop = raw.times[stop_idx]
    assert np.isclose(expected_start, annot['onset'] - offset, atol=1e-4)
    assert np.isclose(
        expected_stop, annot['onset'] + annot['duration'] - offset,
        atol=1e-4)


def test_invalid_arguments():
    """Test error messages raised by invalid arguments.

    Each invalid call must raise a ValueError containing the expected
    message. ``pytest.raises`` treats ``match`` as a regular expression, so
    every message goes through ``re.escape`` to be matched literally
    (otherwise characters such as '.' or parentheses act as regex syntax).
    """
    n_ch, n_times = 2, 100
    data = np.random.RandomState(0).randn(n_ch, n_times)
    info = create_info(n_ch, 100., 'eeg')
    raw = RawArray(data, info, first_samp=0)

    # each case: (keyword arguments of the invalid call, expected message)
    invalid_calls = (
        # negative floats PTP
        (dict(peak=None, flat=-1),
         "Argument 'flat' should define a positive "
         "threshold. Provided: '-1'."),
        (dict(peak=-1, flat=None),
         "Argument 'peak' should define a positive "
         "threshold. Provided: '-1'."),
        # negative PTP threshold for one channel type
        (dict(peak=None, flat=dict(eeg=1, eog=-1)),
         "Argument 'flat' should define positive "
         "thresholds. Provided for channel type "
         "'eog': '-1'."),
        (dict(peak=dict(eeg=1, eog=-1), flat=None),
         "Argument 'peak' should define positive "
         "thresholds. Provided for channel type "
         "'eog': '-1'."),
        # both PTP set to None
        (dict(peak=None, flat=None),
         "At least one of the arguments 'peak' or 'flat' "
         "must not be None."),
        # bad_percent outside [0, 100]
        (dict(peak=dict(eeg=1), flat=None, bad_percent=-1),
         "Argument 'bad_percent' should define a "
         "percentage between 0% and 100%. Provided: "
         "-1.0%."),
        # min_duration negative
        (dict(peak=dict(eeg=1), flat=None, min_duration=-1),
         "Argument 'min_duration' should define a "
         "positive duration in seconds. Provided: "
         "'-1.0' seconds."),
        # min_duration equal to the raw duration
        (dict(peak=dict(eeg=1), flat=None, min_duration=1.),
         "Argument 'min_duration' should define a "
         "positive duration in seconds shorter than the "
         "raw duration (1.0 seconds). Provided: "
         "'1.0' seconds."),
        # min_duration longer than the raw duration
        (dict(peak=dict(eeg=1), flat=None, min_duration=10),
         "Argument 'min_duration' should define a "
         "positive duration in seconds shorter than the "
         "raw duration (1.0 seconds). Provided: "
         "'10.0' seconds."),
    )
    for kwargs, message in invalid_calls:
        with pytest.raises(ValueError, match=re.escape(message)):
            annotate_amplitude(raw, **kwargs)