"""
==============================================
Real-time feedback for decoding :: Server Side
==============================================
This example demonstrates how to set up a real-time feedback
mechanism using StimServer and StimClient.
The idea here is to present future stimuli from the class that
is predicted less accurately. This allows on-demand adaptation
of the stimuli depending on the needs of the classifier.
To run this example, open IPython in two separate terminals.
In the first, run rt_feedback_server.py and then wait for the
message
RtServer: Start
Once that appears, run rt_feedback_client.py in the other terminal
and the feedback script should start.
All brain responses are simulated from a fiff file to make it easy
to test. However, it should be possible to adapt this script
for a real experiment.
"""
# Print the module docstring (the description of this example) at startup
print(__doc__)
# Author: Mainak Jas <mainak@neuro.hut.fi>
#
# License: BSD (3-clause)
import os
import time

import mne
import numpy as np
import matplotlib.pyplot as plt
from mne.datasets import sample
from mne.realtime import StimServer
from mne.realtime import MockRtClient
from mne.decoding import ConcatenateChannels, FilterEstimator
from sklearn import preprocessing
from sklearn.svm import SVC
from sklearn.pipeline import Pipeline
from sklearn.cross_validation import train_test_split
from sklearn.metrics import confusion_matrix
# Load fiff file to simulate data.
# os.path.join is used instead of string `+` so the path is built correctly
# whether data_path() returns a str or a pathlib.Path (newer MNE releases
# return a Path, for which `+` with a str raises TypeError).
data_path = sample.data_path()
raw_fname = os.path.join(data_path, 'MEG', 'sample',
                         'sample_audvis_filt-0-40_raw.fif')
raw = mne.io.Raw(raw_fname, preload=True)
# Instantiating stimulation server
# The with statement is necessary to ensure a clean exit
with StimServer('localhost', port=4218) as stim_server:

    # The channels to be used while decoding
    # NOTE(review): stim=True adds the trigger channel to the features; its
    # value around t=0 may encode the event id and leak the class label into
    # the classifier -- confirm this is intended.
    picks = mne.pick_types(raw.info, meg='grad', eeg=False, eog=True,
                           stim=True, exclude=raw.info['bads'])

    # Simulated real-time client serving data from `raw`
    rt_client = MockRtClient(raw)

    # Constructing the pipeline for classification:
    # filter -> concatenate channels -> standardize -> linear SVM
    filt = FilterEstimator(raw.info, 1, 40)
    scaler = preprocessing.StandardScaler()
    concatenator = ConcatenateChannels()
    clf = SVC(C=1, kernel='linear')

    concat_classifier = Pipeline([('filter', filt), ('concat', concatenator),
                                  ('scaler', scaler), ('svm', clf)])

    stim_server.start(verbose=True)

    # Just some initially decided events to be simulated.
    # The rest will be decided on the fly from the decoding scores.
    ev_list = [4, 3, 4, 3, 4, 3, 4, 3, 4, 3, 4]

    # Event ids of the two classes. Passing them to confusion_matrix keeps the
    # matrix 2x2 (row 0 = event 3 / LV, row 1 = event 4 / RV) even when the
    # test split and its predictions contain a single class; without it the
    # matrix can be 1x1 and indexing cm[1, 1] raises IndexError.
    class_ids = [3, 4]

    epochs_data, labels = [], []
    score_c1, score_c2, score_x = [], [], []

    for ii in range(50):

        # Tell the stim_client about the next stimuli
        stim_server.add_trigger(ev_list[ii])

        # Collecting data for this trial
        epochs_data.append(
            rt_client.get_event_data(event_id=ev_list[ii], tmin=-0.2,
                                     tmax=0.5, picks=picks,
                                     stim_channel='STI 014'))
        labels.append(ev_list[ii])
        time.sleep(1)  # simulating the isi

        # Start decoding only after collecting sufficient data
        if ii < 10:
            continue

        # Stack the trials gathered so far along a new leading trial axis;
        # y holds the matching event ids.
        X = np.array(epochs_data)
        y = np.array(labels)

        # Now start doing rtfeedback
        X_train, X_test, y_train, y_test = train_test_split(X, y,
                                                            test_size=0.2,
                                                            random_state=7)
        y_pred = concat_classifier.fit(X_train, y_train).predict(X_test)
        cm = confusion_matrix(y_test, y_pred, labels=class_ids)

        # Per-class accuracy in percent: correct predictions of a class divided
        # by the number of test trials truly belonging to it (the row sums of
        # cm). The builtin sum(cm, 1) used previously added a start value of 1
        # to the column sums instead. A class absent from the split gets NaN.
        n_true = cm.sum(axis=1)
        scores = [100.0 * cm[k, k] / n_true[k] if n_true[k] else np.nan
                  for k in range(len(class_ids))]
        score_c1.append(scores[0])
        score_c2.append(scores[1])

        # do something if one class is decoded better than the other
        if np.isnan(score_c1[-1]) or np.isnan(score_c2[-1]):
            # One class is missing from this test split, so the two scores
            # cannot be compared; keep alternating the classes instead.
            print("Only one class in the test split, alternating classes")
            ev_list.append(3 if ev_list[-1] == 4 else 4)
        elif score_c1[-1] < score_c2[-1]:
            print("We decoded class RV better than class LV")
            ev_list.append(3)  # adding more LV to future simulated data
        else:
            print("We decoded class LV better than class RV")
            ev_list.append(4)  # adding more RV to future simulated data

        # Clear the figure
        plt.clf()

        # The x-axis for the plot
        score_x.append(ii)

        # Now plot the accuracy of the last 5 decoding steps. Successive
        # plot() calls overlay on the same axes, so the removed plt.hold()
        # is not needed.
        plt.plot(score_x[-5:], score_c1[-5:])
        plt.plot(score_x[-5:], score_c2[-5:])
        plt.xlabel('Trials')
        plt.ylabel('Classification score (% correct)')
        plt.title('Real-time feedback')
        plt.ylim([0, 100])
        plt.xticks(score_x[-5:])
        plt.legend(('LV', 'RV'), loc='upper left')
        # Redraw without blocking so the acquisition loop keeps running
        plt.pause(0.01)

# Keep the final plot on screen after leaving the `with` block
plt.show()