File: plot_compute_rt_decoder.py

package info (click to toggle)
python-mne 0.8.6%2Bdfsg-2
  • links: PTS, VCS
  • area: main
  • in suites: jessie, jessie-kfreebsd
  • size: 87,892 kB
  • ctags: 6,639
  • sloc: python: 54,697; makefile: 165; sh: 15
file content (111 lines) | stat: -rw-r--r-- 3,354 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
"""
=======================
Decoding real-time data
=======================

Supervised machine learning applied to MEG data in sensor space.
Here the classifier is updated every 5 trials and the decoding
accuracy is plotted
"""
# Authors: Mainak Jas <mainak@neuro.hut.fi>
#
# License: BSD (3-clause)

# Echo the module docstring so the description of this example is shown on
# the console when the script is run.
print(__doc__)

import time

import mne
from mne.realtime import MockRtClient, RtEpochs
from mne.datasets import sample

import numpy as np
import matplotlib.pyplot as plt

# Recording from the sample dataset that is replayed in place of a live
# acquisition
data_path = sample.data_path()
raw_fname = '/'.join((data_path, 'MEG', 'sample',
                      'sample_audvis_filt-0-40_raw.fif'))
raw = mne.io.Raw(raw_fname, preload=True)

# Epoch window handed to RtEpochs and the event codes of the two classes
tmin = -0.2
tmax = 0.5
event_id = {'aud_l': 1, 'vis_l': 3}

tr_percent = 60  # percentage of trials meant for training
min_trials = 10  # decoding only starts once this many trials were seen

# Channel selection: gradiometers together with the EOG and stim channels,
# leaving out the channels flagged as bad in the recording
picks = mne.pick_types(raw.info, meg='grad', eeg=False, eog=True, stim=True,
                       exclude=raw.info['bads'])

# Mock client built on the preloaded raw data
rt_client = MockRtClient(raw)

# Real-time epochs fed by the mock client, with rejection thresholds for the
# gradiometer and EOG channels
rt_epochs = RtEpochs(rt_client, event_id, tmin, tmax, picks=picks, decim=1,
                     reject={'grad': 4000e-13, 'eog': 150e-6})

# Acquisition has to be running before the client pushes its buffers
rt_epochs.start()
rt_client.send_data(rt_epochs, picks, tmin=0, tmax=90, buffer_size=1000)

# Sensor-space decoding with a linear SVM follows; number of time points per
# epoch
n_times = len(rt_epochs.times)

from sklearn import preprocessing
from sklearn.svm import SVC
from sklearn.pipeline import Pipeline
from sklearn.cross_validation import cross_val_score, ShuffleSplit

from mne.decoding import ConcatenateChannels, FilterEstimator

# Running record of the decoding results, filled in by the loop below:
# trial index, mean cross-validation score and its standard deviation
scores_x = []
scores = []
std_scores = []

# Individual processing stages
filt = FilterEstimator(rt_epochs.info, 1, 40)
concatenator = ConcatenateChannels()
scaler = preprocessing.StandardScaler()
clf = SVC(C=1, kernel='linear')

# Chain the stages: filter -> concatenate channels -> standardize -> SVM
concat_classifier = Pipeline([
    ('filter', filt),
    ('concat', concatenator),
    ('scaler', scaler),
    ('svm', clf),
])

# Main loop: collect the incoming epochs one at a time and, once enough of
# them are available, re-estimate the cross-validated decoding accuracy and
# refresh the plot of the five most recent estimates.
for ev_num, ev in enumerate(rt_epochs.iter_evoked()):

    print("Just got epoch %d" % (ev_num + 1))

    # Stack the new epoch along a leading "trial" axis and record its label,
    # which is parsed from the evoked comment.  `y` is an array from the
    # very first trial so that len(y) below is always valid.
    if ev_num == 0:
        X = ev.data[None, ...]
        y = np.array([int(ev.comment)])
    else:
        X = np.concatenate((X, ev.data[None, ...]), axis=0)
        y = np.append(y, int(ev.comment))

    # Too few trials so far: keep collecting.
    if ev_num < min_trials:
        continue

    # Five random 80/20 splits over everything received so far; the fixed
    # seed makes the splits reproducible.  Scores are converted to percent.
    cv = ShuffleSplit(len(y), 5, test_size=0.2, random_state=42)
    scores_t = cross_val_score(concat_classifier, X, y, cv=cv,
                               n_jobs=1) * 100

    std_scores.append(scores_t.std())
    scores.append(scores_t.mean())
    scores_x.append(ev_num)

    # Only the five most recent estimates are displayed.
    recent_x = scores_x[-5:]
    recent_mean = np.asarray(scores[-5:])
    recent_std = np.asarray(std_scores[-5:])

    # Plot accuracy.  No plt.hold(True) here: it was redundant after
    # plt.clf() (successive plot calls already draw on the same axes) and
    # the function was removed in Matplotlib 3.0.
    plt.clf()
    plt.plot(recent_x, recent_mean, '+', label="Classif. score")
    plt.plot(recent_x, recent_mean)
    plt.axhline(50, color='k', linestyle='--', label="Chance level")
    # Shaded band: mean +/- one standard deviation across the CV splits
    plt.fill_between(recent_x, recent_mean - recent_std,
                     y2=recent_mean + recent_std, color='b', alpha=0.5)
    plt.xlabel('Trials')
    plt.ylabel('Classification score (% correct)')
    plt.ylim([30, 105])
    plt.title('Real-time decoding')
    # NOTE(review): on a non-interactive backend plt.show() may block until
    # the window is closed -- confirm this is run in interactive mode.
    plt.show()

    time.sleep(0.1)