File: ems.py

package info (click to toggle)
python-mne 0.17%2Bdfsg-1
  • links: PTS, VCS
  • area: main
  • in suites: buster
  • size: 95,104 kB
  • sloc: python: 110,639; makefile: 222; sh: 15
file content (223 lines) | stat: -rw-r--r-- 8,295 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
# Author: Denis Engemann <denis.engemann@gmail.com>
#         Alexandre Gramfort <alexandre.gramfort@telecom-paristech.fr>
#         Jean-Remi King <jeanremi.king@gmail.com>
#
# License: BSD (3-clause)

from collections import Counter

import numpy as np

from .mixin import TransformerMixin, EstimatorMixin
from .base import _set_cv
from ..utils import logger, verbose
from ..parallel import parallel_func
from .. import pick_types, pick_info


class EMS(TransformerMixin, EstimatorMixin):
    """Transformer to compute event-matched spatial filters.

    This version of EMS [1]_ operates on the entire time course. No time
    window needs to be specified. The result is a spatial filter at each
    time point and a corresponding time course. Intuitively, the result
    gives the similarity between the filter at each time point and the
    data vector (sensors) at that time point.

    .. note : EMS only works for binary classification.

    Attributes
    ----------
    filters_ : ndarray, shape (n_channels, n_times)
        The set of spatial filters: one unit-norm spatial filter (column)
        per time point.
    classes_ : ndarray, shape (n_classes,)
        The target classes.

    References
    ----------
    .. [1] Aaron Schurger, Sebastien Marti, and Stanislas Dehaene, "Reducing
           multi-sensor data to a single time course that reveals experimental
           effects", BMC Neuroscience 2013, 14:122
    """

    def __repr__(self):  # noqa: D105
        if hasattr(self, 'filters_'):
            # filters_ is (n_channels, n_times) with one spatial filter per
            # time point, so the number of filters is the size of the last
            # axis (len() would count channels instead).
            return '<EMS: fitted with %i filters on %i classes.>' % (
                self.filters_.shape[-1], len(self.classes_))
        else:
            return '<EMS: not fitted.>'

    def fit(self, X, y):
        """Fit the spatial filters.

        .. note : No normalization is applied by this method. When several
                  channel types are mixed, scale the data by channel type
                  before fitting the spatial filters.

        Parameters
        ----------
        X : array, shape (n_epochs, n_channels, n_times)
            The training data.
        y : array of int, shape (n_epochs)
            The target classes.

        Returns
        -------
        self : returns an instance of self.

        Raises
        ------
        ValueError
            If ``y`` does not contain exactly two distinct classes.
        """
        # Coerce array-likes (e.g. lists) so boolean-mask indexing works.
        X = np.asarray(X)
        y = np.asarray(y)
        classes = np.unique(y)
        if len(classes) != 2:
            raise ValueError('EMS only works for binary classification.')
        self.classes_ = classes
        # Difference between the two class averages: (n_channels, n_times).
        filters = X[y == classes[0]].mean(0) - X[y == classes[1]].mean(0)
        # Scale each time point's spatial filter to unit L2 norm. A time point
        # where both class means coincide has zero norm and yields nan.
        filters /= np.linalg.norm(filters, axis=0)[None, :]
        self.filters_ = filters
        return self

    def transform(self, X):
        """Transform the data by the spatial filters.

        Parameters
        ----------
        X : array, shape (n_epochs, n_channels, n_times)
            The input data.

        Returns
        -------
        X : array, shape (n_epochs, n_times)
            The input data transformed by the spatial filters.
        """
        # Per time point, the dot product of each epoch's sensor vector with
        # that time point's spatial filter (sum over the channel axis).
        Xt = np.sum(X * self.filters_, axis=1)
        return Xt


@verbose
def compute_ems(epochs, conditions=None, picks=None, n_jobs=1, verbose=None,
                cv=None):
    """Compute event-matched spatial filter on epochs.

    This version of EMS [1]_ operates on the entire time course. No time
    window needs to be specified. The result is a spatial filter at each
    time point and a corresponding time course. Intuitively, the result
    gives the similarity between the filter at each time point and the
    data vector (sensors) at that time point.

    .. note : EMS only works for binary classification.

    .. note : The present function applies a leave-one-out cross-validation,
              following Schurger et al's paper. However, we recommend using
              a stratified k-fold cross-validation. Indeed, leave-one-out tends
              to overfit and cannot be used to estimate the variance of the
              prediction within a given fold.

    .. note : Because of the leave-one-out, this function needs an equal
              number of epochs in each of the two conditions. This is checked
              on the selected conditions, after bad epochs have been dropped.

    .. note : Before the filters are computed, the data of each present
              channel type ('mag', 'grad', 'eeg') are divided by their
              standard deviation.

    Parameters
    ----------
    epochs : instance of mne.Epochs
        The epochs.
    conditions : list of str | None, defaults to None
        If a list of strings, strings must match keys of epochs.event_id.
        Exactly two conditions are required. If None, the keys in
        epochs.event_id are used.
    picks : array-like of int | None, defaults to None
        Channels to be included. If None, the good MEG and EEG channels are
        used.
    n_jobs : int, defaults to 1
        Number of jobs to run in parallel.
    verbose : bool, str, int, or None, defaults to None
        If not None, override default verbose level (see :func:`mne.verbose`
        and :ref:`Logging documentation <tut_logging>` for more).
    cv : cross-validation object | str | None, defaults to LeaveOneOut
        The cross-validation scheme.

    Returns
    -------
    surrogate_trials : ndarray, shape (n_splits, n_times)
        The trial surrogates. For each cross-validation split, the first test
        epoch is projected onto the spatial filters fitted on the training
        epochs. With the default leave-one-out scheme, there is one split per
        epoch.
    mean_spatial_filter : ndarray, shape (n_channels, n_times)
        The set of spatial filters, averaged across splits.
    conditions : ndarray, shape (n_epochs,)
        The event id of each epoch used, after condition selection and
        dropping of bad epochs. Values correspond to original event ids.

    Raises
    ------
    ValueError
        If there are not exactly two conditions, or if the selected
        conditions do not have the same number of epochs.

    References
    ----------
    .. [1] Aaron Schurger, Sebastien Marti, and Stanislas Dehaene, "Reducing
           multi-sensor data to a single time course that reveals experimental
           effects", BMC Neuroscience 2013, 14:122
    """
    logger.info('...computing surrogate time series. This can take some time')

    # Default to leave-one-out cv
    cv = 'LeaveOneOut' if cv is None else cv

    if picks is None:
        picks = pick_types(epochs.info, meg=True, eeg=True)

    # Work on a new instance so that drop_bad() below does not alter the
    # caller's epochs (indexing is assumed to return a new Epochs object).
    if conditions is None:
        conditions = epochs.event_id.keys()
        epochs = epochs.copy()
    else:
        epochs = epochs[conditions]

    epochs.drop_bad()

    if len(conditions) != 2:
        raise ValueError('Currently this function expects exactly 2 '
                         'conditions but you gave me %i' %
                         len(conditions))

    # The balance requirement concerns the epochs actually used: only the
    # selected conditions, and only after bad epochs have been dropped.
    if len(set(Counter(epochs.events[:, 2]).values())) != 1:
        raise ValueError('The same number of epochs is required by '
                         'this function. Please consider '
                         '`epochs.equalize_event_counts`')

    ev = epochs.events[:, 2]
    # Special care to avoid path dependent mappings and orders
    conditions = sorted(conditions)
    cond_idx = [np.where(ev == epochs.event_id[k])[0] for k in conditions]

    info = pick_info(epochs.info, picks)
    data = epochs.get_data()[:, picks]

    # Scale (z-score) the data by channel type
    # XXX the z-scoring is applied outside the CV, which is not standard.
    for ch_type in ['mag', 'grad', 'eeg']:
        if ch_type in epochs:
            # FIXME should be applied to all sort of data channels
            if ch_type == 'eeg':
                this_picks = pick_types(info, meg=False, eeg=True)
            else:
                this_picks = pick_types(info, meg=ch_type, eeg=False)
            data[:, this_picks] /= np.std(data[:, this_picks])

    # Setup cross-validation. Need to use _set_cv to deal with sklearn
    # deprecation of cv objects.
    y = epochs.events[:, 2]
    _, cv_splits = _set_cv(cv, 'classifier', X=y, y=y)

    parallel, p_func, _ = parallel_func(_run_ems, n_jobs=n_jobs)
    # FIXME this parallelization should be removed.
    #   1) it's numpy computation so it's already efficient,
    #   2) it duplicates the data in RAM,
    #   3) the computation is already super fast.
    out = parallel(p_func(_ems_diff, data, cond_idx, train, test)
                   for train, test in cv_splits)

    surrogate_trials, spatial_filter = zip(*out)
    surrogate_trials = np.array(surrogate_trials)
    spatial_filter = np.mean(spatial_filter, axis=0)

    return surrogate_trials, spatial_filter, epochs.events[:, 2]


def _ems_diff(data0, data1):
    """Return mean(data0) - mean(data1), both averaged over the first axis.

    This is the default EMS objective function. When it is called from
    ``compute_ems``, the first axis indexes epochs, so the result is the
    difference of the two condition averages.
    """
    avg_first = np.mean(data0, axis=0)
    avg_second = np.mean(data1, axis=0)
    return avg_first - avg_second


def _run_ems(objective_function, data, cond_idx, train, test):
    """Fit EMS spatial filters on one training split and apply them.

    Parameters
    ----------
    objective_function : callable
        Called with one array per condition. Each array holds the training
        epochs of that condition. It must return the unnormalized filters.
    data : ndarray
        The epochs data, indexed by epoch along the first axis.
    cond_idx : list of array of int
        For each condition, the indices of the epochs belonging to it.
    train : array of int
        Indices of the training epochs.
    test : array of int
        Indices of the test epochs. Only ``test[0]`` is projected.

    Returns
    -------
    surrogate : ndarray
        ``data[test[0]] * filters`` summed over axis 0 (the channel axis for
        (n_epochs, n_channels, n_times) data).
    filters : ndarray
        The filters, with each column scaled to unit L2 norm.
    """
    # Restrict every condition to its epochs that are in the training set.
    train_sets = [data[np.intersect1d(idx, train)] for idx in cond_idx]
    filters = objective_function(*train_sets)
    norms = np.sqrt(np.sum(filters ** 2, axis=0))
    filters /= norms[None, :]
    # Project the held-out epoch onto the filters to build its surrogate.
    held_out = data[test[0]]
    surrogate = np.sum(held_out * filters, axis=0)
    return surrogate, filters