# Author: Denis Engemann <denis.engemann@gmail.com>
# Alexandre Gramfort <alexandre.gramfort@telecom-paristech.fr>
# Jean-Remi King <jeanremi.king@gmail.com>
#
# License: BSD (3-clause)

from collections import Counter

import numpy as np

from .mixin import TransformerMixin, EstimatorMixin
from .base import _set_cv
from ..utils import logger, verbose
from ..parallel import parallel_func
from .. import pick_types, pick_info


class EMS(TransformerMixin, EstimatorMixin):
"""Transformer to compute event-matched spatial filters.

This version of EMS [1]_ operates on the entire time course. No time
window needs to be specified. The result is a spatial filter at each
time point and a corresponding time course. Intuitively, the result
gives the similarity between the filter at each time point and the
data vector (sensors) at that time point.

.. note:: EMS only works for binary classification.

Attributes
----------
filters_ : ndarray, shape (n_channels, n_times)
The set of spatial filters.
classes_ : ndarray, shape (n_classes,)
The target classes.
References
----------
.. [1] Aaron Schurger, Sebastien Marti, and Stanislas Dehaene, "Reducing
multi-sensor data to a single time course that reveals experimental
effects", BMC Neuroscience 2013, 14:122
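
Examples
--------
A minimal usage sketch on synthetic data; the array shapes and random
values below are illustrative assumptions only, not taken from any real
recording.

>>> import numpy as np
>>> rng = np.random.RandomState(0)
>>> X = rng.randn(20, 32, 50)  # (n_epochs, n_channels, n_times)
>>> y = np.repeat([0, 1], 10)  # two balanced classes
>>> ems = EMS().fit(X, y)      # filters_ will have shape (32, 50)
>>> Xt = ems.transform(X)      # Xt will have shape (20, 50)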
"""

def __repr__(self):  # noqa: D105
if hasattr(self, 'filters_'):
return '<EMS: fitted with %i filters on %i classes.>' % (
len(self.filters_), len(self.classes_))
else:
return '<EMS: not fitted.>'

def fit(self, X, y):
"""Fit the spatial filters.

.. note:: EMS is fitted on data normalized by channel type before the
spatial filters are computed.

Parameters
----------
X : array, shape (n_epochs, n_channels, n_times)
The training data.
y : array of int, shape (n_epochs,)
The target classes.
Returns
-------
self : returns an instance of self.
"""
classes = np.unique(y)
if len(classes) != 2:
raise ValueError('EMS only works for binary classification.')
self.classes_ = classes
filters = X[y == classes[0]].mean(0) - X[y == classes[1]].mean(0)
filters /= np.linalg.norm(filters, axis=0)[None, :]
self.filters_ = filters
return self

def transform(self, X):
"""Transform the data by the spatial filters.

Parameters
----------
X : array, shape (n_epochs, n_channels, n_times)
The input data.
Returns
-------
X : array, shape (n_epochs, n_times)
The input data transformed by the spatial filters.
"""
Xt = np.sum(X * self.filters_, axis=1)
return Xt


@verbose
def compute_ems(epochs, conditions=None, picks=None, n_jobs=1, verbose=None,
cv=None):
"""Compute event-matched spatial filter on epochs.

This version of EMS [1]_ operates on the entire time course. No time
window needs to be specified. The result is a spatial filter at each
time point and a corresponding time course. Intuitively, the result
gives the similarity between the filter at each time point and the
data vector (sensors) at that time point.

.. note:: EMS only works for binary classification.

.. note:: The present function applies a leave-one-out cross-validation,
following Schurger et al.'s paper. However, we recommend using a
stratified k-fold cross-validation instead, because leave-one-out tends
to overfit and cannot be used to estimate the variance of the
prediction within a given fold.

.. note:: Because of the leave-one-out, this function needs an equal
number of epochs in each of the two conditions.

Parameters
----------
epochs : instance of mne.Epochs
The epochs.
conditions : list of str | None, defaults to None
If a list of strings, each string must match a key of
``epochs.event_id``, and the number of strings must match the number of
conditions supported by the ``objective_function``. If None, the keys
of ``epochs.event_id`` are used.
picks : array-like of int | None, defaults to None
Channels to be included. If None only good data channels are used.
n_jobs : int, defaults to 1
Number of jobs to run in parallel.
verbose : bool, str, int, or None, defaults to None
If not None, override default verbose level (see :func:`mne.verbose`
and :ref:`Logging documentation <tut_logging>` for more).
cv : cross-validation object | str | None, defaults to LeaveOneOut
The cross-validation scheme.
Returns
-------
surrogate_trials : ndarray, shape (n_splits, n_times)
The trial surrogates, one per cross-validation test fold.
mean_spatial_filter : ndarray, shape (n_channels, n_times)
The set of spatial filters.
conditions : ndarray, shape (n_epochs,)
The conditions used for each trial. Values correspond to the original
event ids.
References
----------
.. [1] Aaron Schurger, Sebastien Marti, and Stanislas Dehaene, "Reducing
multi-sensor data to a single time course that reveals experimental
effects", BMC Neuroscience 2013, 14:122
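
Examples
--------
A usage sketch, assuming ``epochs`` is an :class:`mne.Epochs` instance
containing exactly two conditions with equalized trial counts (the
event keys ``'aud'`` and ``'vis'`` below are hypothetical). As
recommended in the note above, a scikit-learn stratified k-fold object
can be passed via ``cv``.

>>> epochs.equalize_event_counts(['aud', 'vis'])  # doctest: +SKIP
>>> surrogates, filters, conditions = compute_ems(
...     epochs, conditions=['aud', 'vis'])  # doctest: +SKIP
>>> from sklearn.model_selection import StratifiedKFold  # doctest: +SKIP
>>> surrogates, filters, conditions = compute_ems(
...     epochs, conditions=['aud', 'vis'],
...     cv=StratifiedKFold(n_splits=5))  # doctest: +SKIP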
"""
logger.info('...computing surrogate time series. This can take some time')
# Default to leave-one-out cv
cv = 'LeaveOneOut' if cv is None else cv
if picks is None:
picks = pick_types(epochs.info, meg=True, eeg=True)
if len(set(Counter(epochs.events[:, 2]).values())) != 1:
raise ValueError('This function requires an equal number of epochs '
'in each condition. Please consider using '
'`epochs.equalize_event_counts`.')
if conditions is None:
conditions = epochs.event_id.keys()
epochs = epochs.copy()
else:
epochs = epochs[conditions]
epochs.drop_bad()
if len(conditions) != 2:
raise ValueError('Currently this function expects exactly 2 '
'conditions but you gave me %i' %
len(conditions))
ev = epochs.events[:, 2]
# Special care to avoid path dependent mappings and orders
conditions = list(sorted(conditions))
cond_idx = [np.where(ev == epochs.event_id[k])[0] for k in conditions]
info = pick_info(epochs.info, picks)
data = epochs.get_data()[:, picks]
# Scale (z-score) the data by channel type
# XXX the z-scoring is applied outside the CV, which is not standard.
for ch_type in ['mag', 'grad', 'eeg']:
if ch_type in epochs:
# FIXME should be applied to all sort of data channels
if ch_type == 'eeg':
this_picks = pick_types(info, meg=False, eeg=True)
else:
this_picks = pick_types(info, meg=ch_type, eeg=False)
data[:, this_picks] /= np.std(data[:, this_picks])
# Setup cross-validation. Need to use _set_cv to deal with sklearn
# deprecation of cv objects.
y = epochs.events[:, 2]
_, cv_splits = _set_cv(cv, 'classifier', X=y, y=y)
parallel, p_func, _ = parallel_func(_run_ems, n_jobs=n_jobs)
# FIXME this parallelization should be removed.
# 1) it's numpy computation so it's already efficient,
# 2) it duplicates the data in RAM,
# 3) the computation is already super fast.
out = parallel(p_func(_ems_diff, data, cond_idx, train, test)
for train, test in cv_splits)
surrogate_trials, spatial_filter = zip(*out)
surrogate_trials = np.array(surrogate_trials)
spatial_filter = np.mean(spatial_filter, axis=0)
return surrogate_trials, spatial_filter, epochs.events[:, 2]


def _ems_diff(data0, data1):
"""Compute the default diff objective function."""
return np.mean(data0, axis=0) - np.mean(data1, axis=0)


def _run_ems(objective_function, data, cond_idx, train, test):
"""Run EMS."""
# Build the filter from the training trials only (the default objective
# function computes the difference of the two condition means).
d = objective_function(*(data[np.intersect1d(c, train)] for c in cond_idx))
# Normalize the filter to unit norm across channels at each time point.
d /= np.sqrt(np.sum(d ** 2, axis=0))[None, :]
# Project the held-out test trial onto the filter to obtain its surrogate
# time course, and also return the filter itself.
return np.sum(data[test[0]] * d, axis=0), d
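

if __name__ == '__main__':
    # Minimal smoke test of the helpers on synthetic data. This block is an
    # illustrative sketch only: the shapes, indices, and random values below
    # are assumptions, not taken from any real recording.
    rng = np.random.RandomState(0)
    data = rng.randn(10, 4, 6)               # (n_epochs, n_channels, n_times)
    cond_idx = [np.arange(0, 5), np.arange(5, 10)]  # two conditions
    train, test = np.arange(1, 10), np.array([0])   # leave trial 0 out
    surrogate, d = _run_ems(_ems_diff, data, cond_idx, train, test)
    assert surrogate.shape == (6,)            # time course of the left-out trial
    assert d.shape == (4, 6)                  # one filter value per channel/time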