1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224
|
"""Utility functions to baseline-correct data."""
# Authors: The MNE-Python contributors.
# License: BSD-3-Clause
# Copyright the MNE-Python contributors.
import numpy as np
from .utils import _check_option, _validate_type, logger, verbose
def _log_rescale(baseline, mode="mean"):
"""Log the rescaling method."""
if baseline is not None:
_check_option(
"mode",
mode,
["logratio", "ratio", "zscore", "mean", "percent", "zlogratio"],
)
msg = f"Applying baseline correction (mode: {mode})"
else:
msg = "No baseline correction applied"
return msg
@verbose
def rescale(data, times, baseline, mode="mean", copy=True, picks=None, verbose=None):
    """Rescale (baseline correct) data.

    Parameters
    ----------
    data : array
        It can be of any shape. The only constraint is that the last
        dimension should be time.
    times : 1D array
        Time instants in seconds.
    %(baseline_rescale)s
    mode : 'mean' | 'ratio' | 'logratio' | 'percent' | 'zscore' | 'zlogratio'
        Perform baseline correction by

        - subtracting the mean of baseline values ('mean')
        - dividing by the mean of baseline values ('ratio')
        - dividing by the mean of baseline values and taking the log
          ('logratio')
        - subtracting the mean of baseline values followed by dividing by
          the mean of baseline values ('percent')
        - subtracting the mean of baseline values and dividing by the
          standard deviation of baseline values ('zscore')
        - dividing by the mean of baseline values, taking the log, and
          dividing by the standard deviation of log baseline values
          ('zlogratio')
    copy : bool
        Whether to return a new instance or modify in place.
    picks : list of int | None
        Data to process along the axis=-2 (None, default, processes all).
    %(verbose)s

    Returns
    -------
    data_scaled : array
        Array of same shape as data after rescaling.

    Raises
    ------
    ValueError
        If the baseline limits do not select at least one sample of
        ``times``.
    """
    if copy:
        data = data.copy()
    # Build the message unconditionally: ``_log_rescale`` is also where ``mode``
    # gets checked, so skipping it for ``verbose=False`` would let an unknown
    # mode slip through to the dispatch below.
    msg = _log_rescale(baseline, mode)
    if verbose is not False:
        logger.info(msg)
    # Nothing to correct without a baseline or without any time samples.
    if baseline is None or data.shape[-1] == 0:
        return data

    # Accept any sequence of times; the comparisons below need an ndarray.
    times = np.asarray(times)
    bmin, bmax = baseline

    # First sample at or after ``bmin`` (or the very first sample).
    if bmin is None:
        imin = 0
    else:
        imin = np.where(times >= bmin)[0]
        if len(imin) == 0:
            raise ValueError(
                f"bmin is too large ({bmin}), it exceeds the largest time value"
            )
        imin = int(imin[0])

    # One past the last sample at or before ``bmax`` (or the end of the data).
    if bmax is None:
        imax = len(times)
    else:
        imax = np.where(times <= bmax)[0]
        if len(imax) == 0:
            raise ValueError(
                f"bmax is too small ({bmax}), it is smaller than the smallest time "
                "value"
            )
        imax = int(imax[-1]) + 1

    if imin >= imax:
        raise ValueError(
            f"Bad rescaling slice ({imin}:{imax}) from time values {bmin}, {bmax}"
        )

    # technically this is inefficient when `picks` is given, but assuming
    # that we generally pick most channels for rescaling, it's not so bad
    mean = np.mean(data[..., imin:imax], axis=-1, keepdims=True)

    # Every ``fun`` works in place on ``d`` so that ``copy=False`` and views
    # obtained through ``picks`` modify the caller's array.
    if mode == "mean":

        def fun(d, m):
            d -= m

    elif mode == "ratio":

        def fun(d, m):
            d /= m

    elif mode == "logratio":

        def fun(d, m):
            d /= m
            np.log10(d, out=d)  # base-10 logarithm, written back in place

    elif mode == "percent":

        def fun(d, m):
            d -= m
            d /= m

    elif mode == "zscore":

        def fun(d, m):
            d -= m
            # std is taken over the (already mean-subtracted) baseline window
            d /= np.std(d[..., imin:imax], axis=-1, keepdims=True)

    elif mode == "zlogratio":

        def fun(d, m):
            d /= m
            np.log10(d, out=d)
            # std is taken over the log-ratio values of the baseline window
            d /= np.std(d[..., imin:imax], axis=-1, keepdims=True)

    else:
        # Defensive: guarantees a clear error instead of an unbound ``fun``
        # should the mode check above ever let an unknown value through.
        raise ValueError(f"Unknown baseline correction mode: {mode!r}")

    if picks is None:
        fun(data, mean)
    else:
        for pi in picks:
            fun(data[..., pi, :], mean[..., pi, :])
    return data
def _check_baseline(baseline, times, sfreq, on_baseline_outside_data="raise"):
"""Check if the baseline is valid and adjust it if requested.
``None`` values inside ``baseline`` will be replaced with ``times[0]`` and
``times[-1]``.
Parameters
----------
baseline : array-like, shape (2,) | None
Beginning and end of the baseline period, in seconds. If ``None``,
assume no baseline and return immediately.
times : array
The time points.
sfreq : float
The sampling rate.
on_baseline_outside_data : 'raise' | 'info' | 'adjust'
What to do if the baseline period exceeds the data.
If ``'raise'``, raise an exception (default).
If ``'info'``, log an info message.
If ``'adjust'``, adjust the baseline such that it is within the data range.
Returns
-------
(baseline_tmin, baseline_tmax) | None
The baseline with ``None`` values replaced with times, and with adjusted times
if ``on_baseline_outside_data='adjust'``; or ``None``, if ``baseline`` is
``None``.
"""
if baseline is None:
return None
_validate_type(baseline, "array-like")
baseline = tuple(baseline)
if len(baseline) != 2:
raise ValueError(
f"baseline must have exactly two elements (got {len(baseline)})."
)
tmin, tmax = times[0], times[-1]
tstep = 1.0 / float(sfreq)
# check default value of baseline and `tmin=0`
if baseline == (None, 0) and tmin == 0:
raise ValueError(
"Baseline interval is only one sample. Use `baseline=(0, 0)` if this is "
"desired."
)
baseline_tmin, baseline_tmax = baseline
if baseline_tmin is None:
baseline_tmin = tmin
baseline_tmin = float(baseline_tmin)
if baseline_tmax is None:
baseline_tmax = tmax
baseline_tmax = float(baseline_tmax)
if baseline_tmin > baseline_tmax:
raise ValueError(
f"Baseline min ({baseline_tmin}) must be less than baseline max ("
f"{baseline_tmax})"
)
if (baseline_tmin < tmin - tstep) or (baseline_tmax > tmax + tstep):
msg = (
f"Baseline interval [{baseline_tmin}, {baseline_tmax}] s is outside of "
f"epochs data [{tmin}, {tmax}] s. Epochs were probably cropped."
)
if on_baseline_outside_data == "raise":
raise ValueError(msg)
elif on_baseline_outside_data == "info":
logger.info(msg)
elif on_baseline_outside_data == "adjust":
if baseline_tmin < tmin - tstep:
baseline_tmin = tmin
if baseline_tmax > tmax + tstep:
baseline_tmax = tmax
return baseline_tmin, baseline_tmax
|