File: plot_decoding_time_generalization.py

"""
========================================================
Decoding sensor space data with over-time generalization
========================================================

This example runs the analysis computed in:

Jean-Remi King, Alexandre Gramfort, Aaron Schurger, Lionel Naccache
and Stanislas Dehaene, "Two distinct dynamic modes subtend the detection of
unexpected sounds", PLOS ONE, 2013

The idea is to fit a decoder at one time instant and assess whether it can
accurately predict the experimental condition at every other time instant
(generalization over time).
"""
print(__doc__)

# Authors: Alexandre Gramfort <alexandre.gramfort@telecom-paristech.fr>
#
# License: BSD (3-clause)

import numpy as np
import matplotlib.pyplot as plt

import mne
from mne.datasets import spm_face
from mne.decoding import time_generalization

data_path = spm_face.data_path()

###############################################################################
# Load and filter data, set up epochs

raw_fname = data_path + '/MEG/spm/SPM_CTF_MEG_example_faces%d_3D_raw.fif'

raw = mne.io.Raw(raw_fname % 1, preload=True)  # Take first run
raw.append(mne.io.Raw(raw_fname % 2, preload=True))  # Take second run too

# Band-pass filter the concatenated runs between 1 and 45 Hz (IIR filter)
raw.filter(1, 45, method='iir')

events = mne.find_events(raw, stim_channel='UPPT001')
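# Map the trigger values found on the stimulus channel to condition names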
event_id = {"faces": 1, "scrambled": 2}
tmin, tmax = -0.1, 0.5

# Set up pick list
picks = mne.pick_types(raw.info, meg=True, eeg=False, stim=True, eog=True,
                       ref_meg=False, exclude='bads')

# Read epochs
decim = 4  # decimate to make the example faster to run
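# Epochs with a magnetometer peak-to-peak amplitude above 1.5e-12 T (1.5 pT)
# are rejected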
epochs = mne.Epochs(raw, events, event_id, tmin, tmax, proj=True,
                    picks=picks, baseline=None, preload=True,
                    reject=dict(mag=1.5e-12), decim=decim)

epochs_list = [epochs[k] for k in event_id]
mne.epochs.equalize_epoch_counts(epochs_list)
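# Equalizing the epoch counts keeps the two classes balanced, so the chance
# level of the ROC AUC computed below stays at 0.5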

###############################################################################
# Run decoding

# Compute the Area Under the Curve (AUC) of the Receiver Operating
# Characteristic (ROC) for each (training time, testing time) pair.
# A perfect decoder would reach an AUC of 1; chance level is 0.5.
# The default classifier is a linear SVM (C=1) after feature scaling.
scores = time_generalization(epochs_list, clf=None, cv=5, scoring="roc_auc",
                             shuffle=True, n_jobs=2)
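
# Any scikit-learn style classifier could be passed via ``clf`` instead of the
# default linear SVM. A minimal sketch (left commented out so the example's
# output is unchanged; assumes scikit-learn is installed):
#
# from sklearn.pipeline import Pipeline
# from sklearn.preprocessing import StandardScaler
# from sklearn.linear_model import LogisticRegression
# clf = Pipeline([('scaler', StandardScaler()),
#                 ('logreg', LogisticRegression())])
# scores = time_generalization(epochs_list, clf=clf, cv=5, scoring="roc_auc",
#                              shuffle=True, n_jobs=2)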

###############################################################################
# Now visualize
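# ``scores`` is an (n_times, n_times) matrix: each row corresponds to a
# training time and each column to a testing time, so its diagonal holds the
# usual "train and test at the same time point" decoding scores.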
times = 1e3 * epochs.times  # convert times to ms

plt.figure()
plt.imshow(scores, interpolation='nearest', origin='lower',
           extent=[times[0], times[-1], times[0], times[-1]],
           vmin=0.1, vmax=0.9, cmap='RdBu_r')
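# The color limits (0.1-0.9) are symmetric around chance (0.5) so that the
# diverging colormap is centered on chance-level performance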
plt.xlabel('Testing time (ms)')
plt.ylabel('Training time (ms)')
plt.title('Time generalization (%s vs. %s)' % tuple(event_id.keys()))
plt.axvline(0, color='k')
plt.axhline(0, color='k')
plt.colorbar()

plt.figure()
plt.plot(times, np.diag(scores), label="Classif. score")
plt.axhline(0.5, color='k', linestyle='--', label="Chance level")
plt.axvline(0, color='r', label='stim onset')
plt.legend()
plt.xlabel('Time (ms)')
plt.ylabel('ROC classification score')
plt.title('Decoding (%s vs. %s)' % tuple(event_id.keys()))
plt.show()