1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370
|
# Author: Mainak Jas
#
# License: BSD (3-clause)
import copy
import re
import threading
import time
import numpy as np
from ..io import _empty_info
from ..io.pick import pick_info
from ..io.constants import FIFF
from ..epochs import EpochsArray
from ..utils import logger, warn
from ..externals.FieldTrip import Client as FtClient
def _buffer_recv_worker(ft_client):
    """Worker thread that constantly receives buffers.

    Every buffer produced by ``ft_client.iter_raw_buffers()`` is pushed to
    the registered callbacks via ``ft_client._push_raw_buffer``. When the
    iteration ends (normally or through an error), ``ft_client._recv_thread``
    is reset to None if it still refers to this thread, so a new receive
    thread can be started afterwards.

    Parameters
    ----------
    ft_client : FieldTripClient
        The client whose buffers are received.
    """
    try:
        for raw_buffer in ft_client.iter_raw_buffers():
            ft_client._push_raw_buffer(raw_buffer)
    except (RuntimeError, IOError) as err:
        # RuntimeError: something is wrong, the server stopped (or
        # something). IOError/OSError: presumably socket-level failures
        # when the connection drops -- TODO confirm which errors the
        # FieldTrip client raises.
        logger.error('Buffer receive thread stopped: %s', err)
    finally:
        # Only clear the reference if it still points at this thread: a
        # newer receive thread may have been started after this one was
        # stopped, and clearing it would make that thread stop as well.
        if ft_client._recv_thread is threading.current_thread():
            ft_client._recv_thread = None
class FieldTripClient(object):
    """Realtime FieldTrip client.

    Parameters
    ----------
    info : dict | None
        The measurement info read in from a file. If None, it is guessed from
        the FieldTrip Header object. When given, a modified deep copy is used
        and the object passed in is left untouched.
    host : str
        Hostname (or IP address) of the host where the FieldTrip buffer runs.
    port : int
        Port to use for the connection.
    wait_max : float
        Maximum time (in seconds) to wait for FieldTrip buffer to start.
    tmin : float | None
        Time instant to start receiving buffers. If None, start from the latest
        samples available.
    tmax : float
        Time instant to stop receiving buffers.
    buffer_size : int
        Size of each buffer in terms of number of samples.
    verbose : bool, str, int, or None
        Log verbosity (see :func:`mne.verbose` and
        :ref:`Logging documentation <tut_logging>` for more).
    """

    def __init__(self, info=None, host='localhost', port=1972, wait_max=30,
                 tmin=None, tmax=np.inf, buffer_size=1000,
                 verbose=None):  # noqa: D102
        self.verbose = verbose

        self.info = info
        self.wait_max = wait_max
        self.tmin = tmin
        self.tmax = tmax
        self.buffer_size = buffer_size

        self.host = host
        self.port = port

        self._recv_thread = None
        self._recv_callbacks = list()

    def __enter__(self):  # noqa: D105
        # instantiate FieldTrip client and connect
        self.ft_client = FtClient()
        self._connect_to_buffer()
        try:
            self._retrieve_header()
            self.info = self._guess_measurement_info()
            self.ch_names = self.ft_header.labels
            self._compute_sample_bounds()
        except Exception:
            # __exit__ is not run when __enter__ raises, so close the
            # connection here instead of leaking it
            self.ft_client.disconnect()
            raise
        return self

    def __exit__(self, type, value, traceback):  # noqa: D105
        self.ft_client.disconnect()

    def _connect_to_buffer(self):
        """Connect ``self.ft_client``, retrying until ``wait_max`` expires.

        Raises
        ------
        RuntimeError
            If no connection could be made within ``wait_max`` seconds.
        """
        logger.info("FieldTripClient: Waiting for server to start")
        deadline = time.time() + self.wait_max
        while time.time() < deadline:
            try:
                self.ft_client.connect(self.host, self.port)
            except Exception:
                # server not reachable yet; retry shortly
                time.sleep(0.1)
            else:
                logger.info("FieldTripClient: Connected")
                return
        raise RuntimeError('Could not connect to FieldTrip Buffer')

    def _retrieve_header(self):
        """Poll for the header until it is available or ``wait_max`` expires.

        Sets ``self.ft_header``.

        Raises
        ------
        RuntimeError
            If no header was received within ``wait_max`` seconds.
        """
        logger.info("FieldTripClient: Retrieving header")
        # reset so a stale header from a previous session is never reused
        self.ft_header = None
        deadline = time.time() + self.wait_max
        while self.ft_header is None and time.time() < deadline:
            self.ft_header = self.ft_client.getHeader()
            if self.ft_header is None:
                time.sleep(0.1)
        if self.ft_header is None:
            raise RuntimeError('Failed to retrieve Fieldtrip header!')
        logger.info("FieldTripClient: Header retrieved")

    def _compute_sample_bounds(self):
        """Set ``tmin_samp``/``tmax_samp`` from ``tmin``/``tmax``."""
        sfreq = self.info['sfreq']
        if self.tmin is None:
            # start from the latest sample already in the buffer
            self.tmin_samp = max(0, self.ft_header.nSamples - 1)
        else:
            self.tmin_samp = int(round(sfreq * self.tmin))

        if self.tmax != np.inf:
            self.tmax_samp = int(round(sfreq * self.tmax))
        else:
            # no end requested: uint32 max acts as an effectively
            # unbounded last sample
            self.tmax_samp = np.iinfo(np.uint32).max

    def _guess_measurement_info(self):
        """Create a minimal Info dictionary for epoching, averaging, etc.

        If ``self.info`` is None, a new Info is guessed from the FieldTrip
        header. Otherwise a deep copy of ``self.info`` is returned with the
        projectors removed and the calibration reset; the original object
        is not modified.

        Returns
        -------
        info : dict
            The measurement info to use for the realtime data.
        """
        if self.info is None:
            return self._info_from_header()

        info = copy.deepcopy(self.info)
        # XXX: the data in real-time mode and offline mode
        # does not match unless this is done
        info['projs'] = list()
        # FieldTrip buffer already does the calibration
        for this_info in info['chs']:
            this_info['range'] = 1.0
            this_info['cal'] = 1.0
            this_info['unit_mul'] = 0
        return info

    def _info_from_header(self):
        """Build a new Info from the labels and sampling rate of the header."""
        warn('Info dictionary not provided. Trying to guess it from '
             'FieldTrip Header object')

        info = _empty_info(self.ft_header.fSample)  # create info

        # modify info attributes according to the FieldTrip Header object
        info['comps'] = list()
        info['projs'] = list()
        info['bads'] = list()

        # channel kind is guessed from the channel-name prefix
        kind_by_prefix = (
            ('EEG', FIFF.FIFFV_EEG_CH),
            ('MEG', FIFF.FIFFV_MEG_CH),
            ('MCG', FIFF.FIFFV_MCG_CH),
            ('EOG', FIFF.FIFFV_EOG_CH),
            ('EMG', FIFF.FIFFV_EMG_CH),
            ('STI', FIFF.FIFFV_STIM_CH),
            ('ECG', FIFF.FIFFV_ECG_CH),
            ('MISC', FIFF.FIFFV_MISC_CH),
            ('SYS', FIFF.FIFFV_SYST_CH),
        )

        info['chs'] = []
        chs_unknown = []  # unrecognized channels
        for idx, ch in enumerate(self.ft_header.labels):
            for prefix, kind in kind_by_prefix:
                if ch.startswith(prefix):
                    break
            else:
                # no prefix matched: mark as MISC and warn later
                kind = FIFF.FIFFV_MISC_CH
                chs_unknown.append(ch)

            if ch.startswith('EEG'):
                coord_frame = FIFF.FIFFV_COORD_HEAD
            elif ch.startswith('MEG'):
                coord_frame = FIFF.FIFFV_COORD_DEVICE
            else:
                coord_frame = FIFF.FIFFV_COORD_UNKNOWN

            # MEG names ending in 1 get T, ending in 2/3 get T/m
            # (presumably Neuromag-style magnetometer/gradiometer naming)
            if ch.startswith('MEG') and ch.endswith('1'):
                unit = FIFF.FIFF_UNIT_T
            elif ch.startswith('MEG') and ch.endswith(('2', '3')):
                unit = FIFF.FIFF_UNIT_T_M
            else:
                unit = FIFF.FIFF_UNIT_V

            info['chs'].append(dict(
                scanno=idx,
                # numerical part of the channel name, or the 1-based
                # position when the name contains no digits
                logno=self._logno_from_name(ch, idx + 1),
                kind=kind,
                # Set coil_type (does FT supply this information somehow?)
                coil_type=FIFF.FIFFV_COIL_NONE,
                # FieldTrip already does calibration
                range=1.0,
                cal=1.0,
                ch_name=ch,
                loc=np.zeros(12),
                coord_frame=coord_frame,
                unit=unit,
                unit_mul=0,
            ))

        info._update_redundant()
        info._check_consistency()

        if chs_unknown:
            msg = ('Following channels in the FieldTrip header were '
                   'unrecognized and marked as MISC: ')
            warn(msg + ', '.join(chs_unknown))
        return info

    @staticmethod
    def _logno_from_name(ch_name, default):
        """Return the last run of digits in ``ch_name`` as int, or default.

        E.g. ``'MEG0111'`` -> 111, ``'EEG 001'`` -> 1, ``'Status'`` ->
        ``default``.
        """
        digits = re.findall(r'\d+', ch_name)
        return int(digits[-1]) if digits else default

    def get_measurement_info(self):
        """Return the measurement info.

        Returns
        -------
        self.info : dict
            The measurement info.
        """
        return self.info

    def get_data_as_epoch(self, n_samples=1024, picks=None):
        """Return last n_samples from current time.

        Fewer samples are returned if fewer than ``n_samples`` have been
        acquired so far.

        Parameters
        ----------
        n_samples : int
            Number of samples to fetch.
        picks : array-like of int | None
            If None all channels are kept
            otherwise the channels indices in picks are kept.

        Returns
        -------
        epoch : instance of Epochs
            The samples fetched as an Epochs object.

        See Also
        --------
        mne.Epochs.iter_evoked
        """
        ft_header = self.ft_client.getHeader()
        last_samp = ft_header.nSamples - 1
        # avoid negative sample indices when the buffer holds fewer than
        # n_samples samples
        start = max(0, last_samp - n_samples + 1)
        stop = last_samp
        events = np.expand_dims(np.array([start, 1, 1]), axis=0)

        # get the data (stop index is inclusive)
        data = self.ft_client.getData([start, stop]).transpose()

        # create epoch from data
        info = self.info
        if picks is None:
            picks = range(info['nchan'])
        else:
            info = pick_info(info, picks)
        epoch = EpochsArray(data[picks][np.newaxis], info, events)

        return epoch

    def register_receive_callback(self, callback):
        """Register a raw buffer receive callback.

        Parameters
        ----------
        callback : callable
            The callback. The raw buffer is passed as the first parameter
            to callback.
        """
        if callback not in self._recv_callbacks:
            self._recv_callbacks.append(callback)

    def unregister_receive_callback(self, callback):
        """Unregister a raw buffer receive callback.

        Parameters
        ----------
        callback : callable
            The callback to unregister.
        """
        if callback in self._recv_callbacks:
            self._recv_callbacks.remove(callback)

    def _push_raw_buffer(self, raw_buffer):
        """Push raw buffer to clients using callbacks."""
        # iterate over a snapshot: callbacks may be (un)registered from
        # another thread while the receive thread is pushing buffers
        for callback in list(self._recv_callbacks):
            callback(raw_buffer)

    def start_receive_thread(self, nchan):
        """Start the receive thread.

        A daemon thread feeding buffers to the registered callbacks is
        started unless one is already running. This does not start a
        measurement on the server.

        Parameters
        ----------
        nchan : int
            Unused; accepted for API compatibility.
        """
        thread = self._recv_thread
        # also restart when the previous worker has already finished
        if thread is None or not thread.is_alive():
            self._recv_thread = threading.Thread(target=_buffer_recv_worker,
                                                 args=(self, ))
            self._recv_thread.daemon = True
            self._recv_thread.start()

    def stop_receive_thread(self, stop_measurement=False):
        """Stop the receive thread.

        The thread is not joined: it stops after yielding its next buffer
        in :meth:`iter_raw_buffers`.

        Parameters
        ----------
        stop_measurement : bool
            unused, for compatibility.
        """
        self._recv_thread = None

    @staticmethod
    def _sample_ranges(first_samp, last_samp, buffer_size):
        """Yield contiguous ``(start, stop)`` pairs covering the samples.

        The pairs cover ``first_samp`` through ``last_samp`` inclusive;
        ``stop`` is exclusive. Every chunk holds ``buffer_size`` samples
        except possibly the last one. Pairs are produced lazily because
        ``last_samp`` can be huge (uint32 max when ``tmax`` is infinite).
        """
        if buffer_size < 1:
            raise ValueError('buffer_size must be a positive integer, got %s'
                             % (buffer_size,))
        end = last_samp + 1  # exclusive upper bound
        start = first_samp
        while start < end:
            stop = min(start + buffer_size, end)
            yield start, stop
            start = stop

    def iter_raw_buffers(self):
        """Return an iterator over raw buffers.

        Buffers cover samples ``tmin_samp`` through ``tmax_samp``
        (inclusive) in chunks of ``buffer_size`` samples; the last chunk may
        be shorter. Before each chunk is fetched, ``ft_client.wait`` is
        called with the chunk's end sample.

        Iteration stops after the current buffer once this generator is no
        longer consumed by the active receive thread (e.g. after
        :meth:`stop_receive_thread`). As a consequence, iterating from any
        other thread yields at most one buffer.

        Returns
        -------
        raw_buffer : generator
            Generator for iteration over raw buffers.
        """
        max_uint32 = np.iinfo(np.uint32).max
        # self.tmax_samp should be included
        ranges = self._sample_ranges(self.tmin_samp, self.tmax_samp,
                                     self.buffer_size)
        for start, stop in ranges:
            # wait for correct number of samples to be available
            self.ft_client.wait(stop, max_uint32, max_uint32)

            # get the samples (stop index is inclusive)
            raw_buffer = self.ft_client.getData([start, stop - 1]).transpose()

            yield raw_buffer

            if self._recv_thread != threading.current_thread():
                # stop_receive_thread has been called
                break
|