File: test_fieldtrip_client.py

package info (click to toggle)
python-mne 0.17%2Bdfsg-1
  • links: PTS, VCS
  • area: main
  • in suites: buster
  • size: 95,104 kB
  • sloc: python: 110,639; makefile: 222; sh: 15
file content (174 lines) | stat: -rw-r--r-- 6,445 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
# Author: Mainak Jas <mainak@neuro.hut.fi>
#
# License: BSD (3-clause)

import time
import os
import threading
import subprocess
import os.path as op
import contextlib
import socket

import numpy as np
from numpy.testing import assert_array_equal, assert_allclose
import pytest

from mne import Epochs, find_events, pick_types
from mne.io import read_raw_fif
from mne.utils import requires_neuromag2ft
from mne.utils import run_tests_if_main
from mne.realtime import FieldTripClient, RtEpochs
from mne.externals.six.moves import queue

from mne.realtime.tests.test_mockclient import _call_base_epochs_public_api

# Set our plotters to test mode by selecting matplotlib's 'Agg' backend
import matplotlib
matplotlib.use('Agg')  # for testing don't use X server

# Data directory of the io tests, reached by going two levels up from the
# directory that contains this file.
base_dir = op.join(op.dirname(__file__), op.pardir, op.pardir,
                   'io', 'tests', 'data')
# Fully resolved path of the raw FIF recording used by the tests below.
raw_fname = op.realpath(op.join(base_dir, 'test_raw.fif'))


@pytest.fixture
def free_tcp_port():
    """Return a TCP port number obtained by binding to port 0 on 127.0.0.1.

    A throwaway socket is bound to ``('127.0.0.1', 0)`` and the port number
    it ended up with is returned; the socket is always closed again before
    the fixture returns.
    """
    probe = socket.socket()
    try:
        probe.bind(('127.0.0.1', 0))
        address = probe.getsockname()
        return address[1]
    finally:
        probe.close()


def _run_buffer(kill_signal, server_port):
    """Start a FieldTrip Buffer from FIF file as a subprocess.

    The ``neuromag2ft`` executable found under ``$NEUROMAG2FT_ROOT`` is
    launched to play back ``raw_fname``. The function then blocks until
    an item arrives on ``kill_signal`` and finally terminates the
    subprocess.

    Parameters
    ----------
    kill_signal : Queue
        Queue on which the caller puts any item to request shutdown.
    server_port : int
        Port handed to ``neuromag2ft`` through ``--bufport``.

    Notes
    -----
    If nothing arrives on ``kill_signal`` within 40 seconds, the exception
    raised by ``kill_signal.get`` propagates, but the subprocess is still
    terminated and its pipes are closed first.
    """
    neuromag2ft_fname = op.realpath(op.join(os.environ['NEUROMAG2FT_ROOT'],
                                    'neuromag2ft'))
    # Works with neuromag2ft-3.0.2
    cmd = (neuromag2ft_fname, '--file', raw_fname, '--speed', '8.0',
           '--bufport', str(server_port))

    # NOTE(review): stdout/stderr are piped but never read in this function.
    process = subprocess.Popen(cmd, stdout=subprocess.PIPE,
                               stderr=subprocess.PIPE)
    try:
        # Let measurement continue until the caller asks us to stop
        kill_signal.get(timeout=40.0)
    finally:
        # Always clean up, even when the wait above timed out; otherwise
        # the playback process would be left running.
        process.terminate()
        with pytest.warns(None):  # still running
            process.stderr.close()
            process.stdout.close()
            del process


def _start_buffer_thread(buffer_port):
    """Run ``_run_buffer`` on a daemon thread and return its stop queue.

    Parameters
    ----------
    buffer_port : int
        Port forwarded to ``_run_buffer`` as ``server_port``.

    Returns
    -------
    stop_queue : Queue
        The queue given to ``_run_buffer`` as ``kill_signal``; putting an
        item on it ends the wait in the background thread.
    """
    stop_queue = queue.Queue()
    worker = threading.Thread(
        target=_run_buffer,
        kwargs=dict(kill_signal=stop_queue, server_port=buffer_port))
    # daemon thread, so a still-waiting buffer never blocks interpreter exit
    worker.daemon = True
    worker.start()
    return stop_queue


@pytest.mark.slowtest
@requires_neuromag2ft
def test_fieldtrip_rtepochs(free_tcp_port, tmpdir):
    """Test that RtEpochs fed by a FieldTripClient match offline Epochs."""
    # Offline reference: epochs cut from the cropped raw recording
    stim_channel = 'STI 014'
    raw_tmax = 7
    tmin, tmax = -0.2, 0.5
    raw = read_raw_fif(raw_fname, preload=True)
    raw.crop(tmin=0, tmax=raw_tmax)
    events_offline = find_events(raw, stim_channel=stim_channel)
    event_id = list(np.unique(events_offline[:, 2]))
    epochs_offline = Epochs(raw, events_offline, event_id=event_id,
                            tmin=tmin, tmax=tmax)
    epochs_offline.drop_bad()
    # largest gap between kept events divided by the sampling rate, plus 1.0
    event_gaps = np.diff(epochs_offline.events[:, 0])
    isi_max = (np.max(event_gaps) / raw.info['sfreq']) + 1.0

    kill_signal = _start_buffer_thread(free_tcp_port)

    try:
        data_rt = None
        events_ids_rt = None
        with pytest.warns(RuntimeWarning, match='Trying to guess it'):
            with FieldTripClient(host='localhost', port=free_tcp_port,
                                 tmax=raw_tmax, wait_max=2) as rt_client:
                # channel names reported by the client must match the file
                raw_info = rt_client.get_measurement_info()
                names_rt = [ch['ch_name'] for ch in raw_info['chs']]
                names_raw = [ch['ch_name'] for ch in raw.info['chs']]
                assert names_rt == names_raw

                # real-time epochs built on top of the client
                epochs_rt = RtEpochs(rt_client, event_id, tmin, tmax,
                                     stim_channel=stim_channel,
                                     isi_max=isi_max)
                epochs_rt.start()

                time.sleep(0.5)
                # stack every evoked along a new leading axis and collect
                # the event id stored in its comment attribute
                for ev in epochs_rt.iter_evoked():
                    ev_data = ev.data[np.newaxis, :, :]
                    ev_id = int(ev.comment)
                    if data_rt is None:  # first evoked
                        data_rt = ev_data
                        events_ids_rt = ev_id
                    else:
                        data_rt = np.concatenate((data_rt, ev_data), axis=0)
                        events_ids_rt = np.append(events_ids_rt, ev_id)

                _call_base_epochs_public_api(epochs_rt, tmpdir)
                epochs_rt.stop(stop_receive_thread=True)

        # what was iterated must equal what the RtEpochs object holds ...
        assert_array_equal(events_ids_rt, epochs_rt.events[:, 2])
        assert_array_equal(data_rt, epochs_rt.get_data())
        # ... and the real-time result must agree with the offline epochs
        assert len(epochs_rt) == len(epochs_offline)
        assert_array_equal(events_ids_rt, epochs_offline.events[:, 2])
        assert_allclose(epochs_rt.get_data(), epochs_offline.get_data(),
                        rtol=1.e-5, atol=1.e-8)  # defaults of np.isclose
    finally:
        kill_signal.put(False)  # stop the buffer


@requires_neuromag2ft
def test_fieldtrip_client(free_tcp_port):
    """Test connecting a FieldTripClient twice and fetching short epochs."""
    kill_signal = _start_buffer_thread(free_tcp_port)
    # both connections below use exactly the same client settings
    client_kwargs = dict(host='localhost', port=free_tcp_port,
                         tmax=5, wait_max=2)

    time.sleep(0.5)

    try:
        # First connection: only remember the client's starting sample
        with pytest.warns(RuntimeWarning):
            with FieldTripClient(**client_kwargs) as rt_client:
                tmin_samp1 = rt_client.tmin_samp

        time.sleep(1)  # Pause measurement

        # Second connection: starting sample plus a few data requests
        with pytest.warns(RuntimeWarning):
            with FieldTripClient(**client_kwargs) as rt_client:
                raw_info = rt_client.get_measurement_info()

                tmin_samp2 = rt_client.tmin_samp
                picks = pick_types(raw_info, meg='grad', eeg=False,
                                   stim=False, eog=False)

                first = rt_client.get_data_as_epoch(n_samples=5, picks=picks)
                n_channels, n_samples = first.get_data().shape[1:]

                second = rt_client.get_data_as_epoch(n_samples=5,
                                                     picks=picks)
                n_channels2, n_samples2 = second.get_data().shape[1:]

                # case of picks=None
                rt_client.get_data_as_epoch(n_samples=5)

        # the later connection must start at a later sample
        assert tmin_samp2 > tmin_samp1
        assert (n_samples, n_samples2) == (5, 5)
        assert (n_channels, n_channels2) == (len(picks), len(picks))
    finally:
        kill_signal.put(False)  # stop the buffer


# Hook from mne.utils called at import time; presumably it runs this
# module's tests only when the file is executed as a script -- TODO confirm.
run_tests_if_main()