# Author: Mathieu Scheltienne <mathieu.scheltienne@fcbg.ch>
#
# License: BSD-3-Clause
import numpy as np
from ..fixes import jit
from ..io import BaseRaw
from ..annotations import (Annotations, _adjust_onset_meas_date,
_annotations_starts_stops)
from ..io.pick import _picks_to_idx, _picks_by_type, _get_channel_types
from ..utils import _validate_type, verbose, logger, _mask_to_onsets_offsets
@verbose
def annotate_amplitude(raw, peak=None, flat=None, bad_percent=5,
                       min_duration=0.005, picks=None, *, verbose=None):
    """Annotate raw data based on peak-to-peak amplitude.

    Creates annotations ``BAD_peak`` or ``BAD_flat`` for spans of data where
    consecutive samples exceed the threshold in ``peak`` or fall below the
    threshold in ``flat`` for more than ``min_duration``.
    Channels where more than ``bad_percent`` of the total recording length
    should be annotated with either ``BAD_peak`` or ``BAD_flat`` are returned
    in ``bads`` instead.
    Note that the annotations and the bads are not automatically added to the
    :class:`~mne.io.Raw` object; use :meth:`~mne.io.Raw.set_annotations` and
    :class:`info['bads'] <mne.Info>` to do so.

    Parameters
    ----------
    raw : instance of Raw
        The raw data.
    peak : float | dict | None
        Annotate segments based on **maximum** peak-to-peak signal amplitude
        (PTP). Valid **keys** can be any channel type present in the object.
        The **values** are floats that set the maximum acceptable PTP. If the
        PTP is larger than this threshold, the segment will be annotated.
        If float, the maximum acceptable PTP is applied to all channels.
    flat : float | dict | None
        Annotate segments based on **minimum** peak-to-peak signal amplitude
        (PTP). Valid **keys** can be any channel type present in the object.
        The **values** are floats that set the minimum acceptable PTP. If the
        PTP is smaller than this threshold, the segment will be annotated.
        If float, the minimum acceptable PTP is applied to all channels.
    bad_percent : float
        The percentage of the time a channel can be above or below thresholds.
        Below this percentage, :class:`~mne.Annotations` are created.
        Above this percentage, the channel involved is returned in ``bads``.
        Note the returned ``bads`` are not automatically added to
        :class:`info['bads'] <mne.Info>`.
        Defaults to ``5``, i.e. 5%%.
    min_duration : float
        The minimum duration (sec) required by consecutive samples to be
        above ``peak`` or below ``flat`` thresholds to be considered.
        For some systems, adjacent time samples with exactly the same value
        are not totally uncommon. Defaults to ``0.005`` (5 ms).
    %(picks_good_data)s
    %(verbose)s

    Returns
    -------
    annotations : instance of Annotations
        The annotated bad segments.
    bads : list
        The channels detected as bad.

    Notes
    -----
    This function does not use a window to detect small peak-to-peak or large
    peak-to-peak amplitude changes as the ``reject`` and ``flat`` argument
    from :class:`~mne.Epochs` does. Instead, it looks at the difference
    between consecutive samples.

    - When used to detect segments below ``flat``, at least ``min_duration``
      seconds of consecutive samples must respect
      ``abs(a[i+1] - a[i]) ≤ flat``.
    - When used to detect segments above ``peak``, at least ``min_duration``
      seconds of consecutive samples must respect
      ``abs(a[i+1] - a[i]) ≥ peak``.

    Thus, this function does not detect every temporal event with large
    peak-to-peak amplitude, but only the ones where the peak-to-peak
    amplitude is supra-threshold between consecutive samples. For instance,
    segments experiencing a DC shift will not be picked up. Only the edges
    from the DC shift will be annotated (and those only if the edge
    transitions are longer than ``min_duration``).

    This function may perform faster if data is loaded in memory, as it
    loads data one channel type at a time (across all time points), which is
    typically not an efficient way to read raw data from disk.

    .. versionadded:: 1.0
    """
    _validate_type(raw, BaseRaw, 'raw')
    picks_ = _picks_to_idx(raw.info, picks, 'data_or_ica', exclude='bads')
    # normalize thresholds to {ch_type: float} dicts (or leave as None)
    peak = _check_ptp(peak, 'peak', raw.info, picks_)
    flat = _check_ptp(flat, 'flat', raw.info, picks_)
    if peak is None and flat is None:
        raise ValueError(
            "At least one of the arguments 'peak' or 'flat' must not be None.")
    bad_percent = _check_bad_percent(bad_percent)
    min_duration = _check_min_duration(min_duration,
                                       raw.times.size * 1 / raw.info['sfreq'])
    # threshold durations expressed in samples for the segment-length check
    min_duration_samples = int(np.round(min_duration * raw.info['sfreq']))
    bads = list()

    # grouping picks by channel types to avoid operating on each channel
    # individually
    picks = {
        ch_type: np.intersect1d(picks_of_type, picks_, assume_unique=True)
        for ch_type, picks_of_type in _picks_by_type(raw.info, exclude='bads')
        if np.intersect1d(picks_of_type, picks_, assume_unique=True).size != 0
    }
    del picks_  # re-using this variable name in for loop

    # skip BAD_acq_skip sections
    onsets, ends = _annotations_starts_stops(raw, 'bad_acq_skip', invert=True)
    # flat index into raw.times for every retained (non-skipped) sample,
    # used below to map positions in the concatenated data back to raw time
    index = np.concatenate([np.arange(raw.times.size)[onset:end]
                            for onset, end in zip(onsets, ends)])
    # size matching the diff a[i+1] - a[i]
    any_flat = np.zeros(len(raw.times) - 1, bool)
    any_peak = np.zeros(len(raw.times) - 1, bool)

    # look for discrete difference above or below thresholds
    logger.info('Finding segments below or above PTP threshold.')
    for ch_type, picks_ in picks.items():
        data = np.concatenate([raw[picks_, onset:end][0]
                               for onset, end in zip(onsets, ends)], axis=1)
        diff = np.abs(np.diff(data, axis=1))
        if flat is not None:
            flat_ = diff <= flat[ch_type]
            # reject too short segments
            flat_ = _reject_short_segments(flat_, min_duration_samples)
            # reject channels above maximum bad_percentage
            flat_count = flat_.sum(axis=1)
            flat_count[np.nonzero(flat_count)] += 1  # offset by 1 due to diff
            flat_mean = flat_count / raw.times.size * 100
            flat_ch_to_set_bad = picks_[np.where(flat_mean >= bad_percent)[0]]
            bads.extend(flat_ch_to_set_bad)

            # add onset/offset for annotations
            flat_ch_to_annotate = \
                np.where((0 < flat_mean) & (flat_mean < bad_percent))[0]
            # convert from raw.times[onset:end] - 1 to raw.times[:] - 1
            idx = index[np.where(flat_[flat_ch_to_annotate, :])[1]]
            any_flat[idx] = True

        if peak is not None:
            peak_ = diff >= peak[ch_type]
            # reject too short segments
            peak_ = _reject_short_segments(peak_, min_duration_samples)
            # reject channels above maximum bad_percentage
            peak_count = peak_.sum(axis=1)
            peak_count[np.nonzero(peak_count)] += 1  # offset by 1 due to diff
            peak_mean = peak_count / raw.times.size * 100
            peak_ch_to_set_bad = picks_[np.where(peak_mean >= bad_percent)[0]]
            bads.extend(peak_ch_to_set_bad)

            # add onset/offset for annotations
            peak_ch_to_annotate = \
                np.where((0 < peak_mean) & (peak_mean < bad_percent))[0]
            # convert from raw.times[onset:end] - 1 to raw.times[:] - 1
            idx = index[np.where(peak_[peak_ch_to_annotate, :])[1]]
            any_peak[idx] = True

    # annotation for flat
    annotation_flat = _create_annotations(any_flat, 'flat', raw)
    # annotation for peak
    annotation_peak = _create_annotations(any_peak, 'peak', raw)
    # group
    annotations = annotation_flat + annotation_peak

    # bads: map channel indices back to names.
    # NOTE(review): `bad` is an integer index here while info['bads'] holds
    # channel *names*, so the `not in` filter is always True (picks already
    # exclude bads above). A channel flagged by both flat and peak criteria
    # may also appear twice in the returned list — confirm intended.
    bads = [raw.ch_names[bad] for bad in bads if bad not in raw.info['bads']]
    return annotations, bads
def _check_ptp(ptp, name, info, picks):
    """Check a PTP threshold argument and convert it to a dict if needed.

    Parameters
    ----------
    ptp : float | dict | None
        The ``peak`` or ``flat`` argument: a single threshold applied to all
        channel types, a ``{ch_type: threshold}`` dict, or None (disabled).
    name : str
        Name of the argument being validated ('peak' or 'flat'), used in
        error messages.
    info : instance of Info
        The measurement info, used to look up the channel types of ``picks``.
    picks : array of int
        Indices of the picked channels.

    Returns
    -------
    ptp : dict | None
        ``{ch_type: float}`` covering every picked channel type, or None.

    Raises
    ------
    ValueError
        If any provided threshold is negative.
    """
    # fix: pass `name` so a TypeError from _validate_type names the
    # offending argument ('peak'/'flat'), consistent with the other checkers
    _validate_type(ptp, ('numeric', dict, None), name)
    if ptp is not None and not isinstance(ptp, dict):
        if ptp < 0:
            raise ValueError(
                f"Argument '{name}' should define a positive threshold. "
                f"Provided: '{ptp}'.")
        # broadcast the scalar threshold to every picked channel type
        ch_types = set(_get_channel_types(info, picks))
        ptp = {ch_type: ptp for ch_type in ch_types}
    elif isinstance(ptp, dict):
        for key, value in ptp.items():
            if value < 0:
                raise ValueError(
                    f"Argument '{name}' should define positive thresholds. "
                    f"Provided for channel type '{key}': '{value}'.")
    return ptp
def _check_bad_percent(bad_percent):
    """Validate ``bad_percent`` and return it as a float percentage.

    Raises ValueError if the value is outside [0, 100].
    """
    _validate_type(bad_percent, 'numeric', 'bad_percent')
    bad_percent = float(bad_percent)
    # reject anything outside the closed interval [0, 100]
    if bad_percent < 0 or bad_percent > 100:
        raise ValueError(
            "Argument 'bad_percent' should define a percentage between 0% "
            f"and 100%. Provided: {bad_percent}%.")
    return bad_percent
def _check_min_duration(min_duration, raw_duration):
    """Validate ``min_duration`` and return it as a float in seconds.

    The duration must be non-negative and strictly shorter than the raw
    recording duration; otherwise a ValueError is raised.
    """
    _validate_type(min_duration, 'numeric', 'min_duration')
    min_duration = float(min_duration)
    # negative durations are meaningless
    if min_duration < 0:
        raise ValueError(
            "Argument 'min_duration' should define a positive duration in "
            f"seconds. Provided: '{min_duration}' seconds.")
    # a minimum segment longer than the recording could never be satisfied
    if min_duration >= raw_duration:
        raise ValueError(
            "Argument 'min_duration' should define a positive duration in "
            f"seconds shorter than the raw duration ({raw_duration} seconds). "
            f"Provided: '{min_duration}' seconds.")
    return min_duration
def _reject_short_segments(arr, min_duration_samples):
    """Clear, in-place, flat/peak runs shorter than the minimum duration.

    ``arr`` is a 2D boolean array (channels x samples); for each channel the
    True runs shorter than ``min_duration_samples`` are set back to False.
    Returns ``arr`` (mutated in place).
    """
    assert arr.dtype == bool and arr.ndim == 2
    for row in arr:  # each row is a view, so _mark_inner mutates arr itself
        starts, stops = _mask_to_onsets_offsets(row)
        _mark_inner(row, starts, stops, min_duration_samples)
    return arr
@jit()
def _mark_inner(arr_k, onsets, offsets, min_duration_samples):
    """Inner loop of _reject_short_segments().

    Zero out each [start, stop) run of ``arr_k`` whose length is below
    ``min_duration_samples``. Index-based loop kept numba-friendly.
    """
    for i in range(len(onsets)):
        start = onsets[i]
        stop = offsets[i]
        if stop - start < min_duration_samples:
            arr_k[start:stop] = False
def _create_annotations(any_arr, kind, raw):
    """Build ``BAD_peak`` or ``BAD_flat`` annotations from a sample mask.

    ``any_arr`` is a boolean array over sample-difference positions; each
    True run becomes one annotation, converted from samples to seconds.
    """
    assert kind in ('peak', 'flat')
    sample_starts, sample_stops = _mask_to_onsets_offsets(any_arr)
    sample_starts = np.array(sample_starts)
    sample_stops = np.array(sample_stops)
    # samples -> seconds
    sfreq = raw.info['sfreq']
    onsets = sample_starts / sfreq
    durations = (sample_stops - sample_starts) / sfreq
    annot = Annotations(onsets, durations, [f'BAD_{kind}'] * len(onsets),
                        orig_time=raw.info['meas_date'])
    _adjust_onset_meas_date(annot, raw)
    return annot