File: test_csd.py

package info (click to toggle)
python-mne 0.8.6+dfsg-2
  • links: PTS, VCS
  • area: main
  • in suites: jessie, jessie-kfreebsd
  • size: 87,892 kB
  • ctags: 6,639
  • sloc: python: 54,697; makefile: 165; sh: 15
file content (167 lines) | stat: -rw-r--r-- 7,401 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
import numpy as np
from nose.tools import (assert_raises, assert_equal, assert_almost_equal,
                        assert_true)
from numpy.testing import assert_array_equal
from os import path as op
import warnings

import mne

from mne.io import Raw
from mne.utils import sum_squared
from mne.time_frequency import compute_epochs_csd, induced_power

# Report every warning each time it is raised instead of only the first
# occurrence, so repeated warnings are never silently deduplicated.
warnings.simplefilter('always')

# The shared test recordings are located at ../../io/tests/data relative to
# the directory containing this test module.
base_dir = op.join(op.dirname(__file__), '..', '..', 'io', 'tests', 'data')
raw_fname, event_fname = (op.join(base_dir, fname)
                          for fname in ('test_raw.fif', 'test-eve.fif'))


def _get_data():
    """Load the fixtures shared by the CSD tests.

    Returns
    -------
    epochs : instance of Epochs
        MEG-only epochs (event id 1, -0.2 to 0.5 s) built from the first
        100 events of the test recording, excluding channels marked bad.
    epochs_sin : instance of Epochs
        Epochs restricted to channel 0 (0 to 1 s) whose data array is
        replaced by a single epoch of a 10 Hz unit-amplitude sine wave.
    """
    raw = Raw(raw_fname)
    # Mark two channels as bad so that pick_types(exclude='bads') drops them
    raw.info['bads'] = ['MEG 2443', 'EEG 053']
    meg_picks = mne.pick_types(raw.info, meg=True, eeg=False, eog=False,
                               stim=False, exclude='bads')
    events = mne.read_events(event_fname)[:100]

    epochs = mne.Epochs(raw, events, 1, -0.2, 0.5, proj=True,
                        picks=meg_picks, baseline=(None, 0), preload=True,
                        reject=dict(grad=4000e-13, mag=4e-12))

    # Template for the synthetic signal: channel 0 only, 0 to 1 s
    epochs_sin = mne.Epochs(raw, events[:5], 1, 0.0, 1.0, proj=True,
                            picks=[0], baseline=(None, 0), preload=True,
                            reject=dict(grad=4000e-13))
    # Replace the data with one epoch x one channel of a pure sinusoid so the
    # expected power can be computed directly from the time-domain samples
    sine_freq = 10
    sine = np.sin(2 * np.pi * sine_freq * epochs_sin.times)
    epochs_sin._data = sine[np.newaxis, np.newaxis, :]
    return epochs, epochs_sin


def test_compute_epochs_csd():
    """Test computing cross-spectral density from epochs

    Checks parameter validation, the shape and Hermitian symmetry of the
    returned CSD matrices, agreement with induced power on which channel has
    the largest power, and that ``fsum=True`` equals the sum of the
    per-frequency matrices returned with ``fsum=False``.
    """
    epochs, _ = _get_data()
    # Check that wrong parameters are recognized
    assert_raises(ValueError, compute_epochs_csd, epochs, mode='notamode')
    assert_raises(ValueError, compute_epochs_csd, epochs, fmin=20, fmax=10)
    assert_raises(ValueError, compute_epochs_csd, epochs, fmin=20, fmax=20.1)
    assert_raises(ValueError, compute_epochs_csd, epochs, tmin=0.15, tmax=0.1)
    assert_raises(ValueError, compute_epochs_csd, epochs, tmin=0, tmax=10)
    assert_raises(ValueError, compute_epochs_csd, epochs, tmin=10, tmax=11)

    data_csd_mt = compute_epochs_csd(epochs, mode='multitaper', fmin=8,
                                     fmax=12, tmin=0.04, tmax=0.15)
    data_csd_fourier = compute_epochs_csd(epochs, mode='fourier', fmin=8,
                                          fmax=12, tmin=0.04, tmax=0.15)

    # Check shape of the CSD matrix
    n_chan = len(data_csd_mt.ch_names)
    assert_equal(data_csd_mt.data.shape, (n_chan, n_chan))
    assert_equal(data_csd_fourier.data.shape, (n_chan, n_chan))

    # Check if the CSD matrix is hermitian
    assert_array_equal(np.tril(data_csd_mt.data).T.conj(),
                       np.triu(data_csd_mt.data))
    assert_array_equal(np.tril(data_csd_fourier.data).T.conj(),
                       np.triu(data_csd_fourier.data))

    # Computing induced power for comparison over the same time window.
    # crop() is relied on to act in place (its return value is unused), so
    # `epochs` stays cropped for the rest of this test.
    epochs.crop(tmin=0.04, tmax=0.15)
    with warnings.catch_warnings(record=True):  # deprecation
        warnings.simplefilter('always')
        power, _ = induced_power(epochs.get_data(), epochs.info['sfreq'], [10],
                                 n_cycles=0.6)
    power = np.mean(power, 2)

    # Maximum PSD should occur for the same channel in all three estimates
    max_ch_power = power.argmax()
    max_ch_mt = data_csd_mt.data.diagonal().argmax()
    max_ch_fourier = data_csd_fourier.data.diagonal().argmax()
    assert_equal(max_ch_mt, max_ch_power)
    assert_equal(max_ch_fourier, max_ch_power)

    # The channel with the largest off-diagonal |CSD| relative to the
    # max-power channel should agree between the two estimation modes.
    # np.abs returns a fresh array, so zeroing the diagonal entry (the
    # auto-spectrum) does not modify the CSD data itself.
    ch_csd_mt = np.abs(data_csd_mt.data[max_ch_power])
    ch_csd_mt[max_ch_power] = 0
    max_ch_csd_mt = np.argmax(ch_csd_mt)
    ch_csd_fourier = np.abs(data_csd_fourier.data[max_ch_power])
    ch_csd_fourier[max_ch_power] = 0
    max_ch_csd_fourier = np.argmax(ch_csd_fourier)
    assert_equal(max_ch_csd_mt, max_ch_csd_fourier)

    # Check a list of CSD matrices is returned for multiple frequencies within
    # a given range when fsum=False, and that they sum to the fsum=True result
    csd_fsum = compute_epochs_csd(epochs, mode='fourier', fmin=8, fmax=20,
                                  fsum=True)
    csds = compute_epochs_csd(epochs, mode='fourier', fmin=8, fmax=20,
                              fsum=False)
    freqs = [csd.frequencies[0] for csd in csds]

    csd_sum = np.zeros_like(csd_fsum.data)
    for csd in csds:
        csd_sum += csd.data

    # Use nose helpers rather than bare asserts, which are removed under -O
    assert_equal(len(csds), 2)
    assert_equal(len(csd_fsum.frequencies), 2)
    assert_array_equal(csd_fsum.frequencies, freqs)
    assert_array_equal(csd_fsum.data, csd_sum)


def test_compute_epochs_csd_on_artificial_data():
    """Test computing CSD on artificial data

    Uses the single sine-wave epoch from ``_get_data``. The power recovered
    from the CSD diagonal, in both ``'fourier'`` and ``'multitaper'`` modes,
    must match the power computed directly in the time domain. The
    per-sample power must not depend on the time-window length, the FFT
    length or (for multitaper) the bandwidth / number of tapers.
    """
    _, epochs_sin = _get_data()
    sfreq = epochs_sin.info['sfreq']

    # Computing signal power in the time domain
    signal_power = sum_squared(epochs_sin._data)
    signal_power_per_sample = signal_power / len(epochs_sin.times)

    # Computing signal power in the frequency domain
    data_csd_fourier = compute_epochs_csd(epochs_sin, mode='fourier')
    data_csd_mt = compute_epochs_csd(epochs_sin, mode='multitaper')
    fourier_power = np.abs(data_csd_fourier.data[0, 0]) * sfreq
    mt_power = np.abs(data_csd_mt.data[0, 0]) * sfreq
    assert_true(abs(fourier_power - signal_power) <= 0.5)
    assert_true(abs(mt_power - signal_power) <= 1)

    # Extra FFT length (in samples) on top of the window length:
    # 0 -> n_fft equals the window length, 30 -> zero-padded FFT
    extra_n_fft = (0, 30)

    # Power per sample should not depend on time window length
    for tmax in [0.2, 0.4, 0.6, 0.8]:
        # Number of samples in the analysis window depends only on tmax
        t_mask = (epochs_sin.times >= 0) & (epochs_sin.times <= tmax)
        n_samples = sum(t_mask)

        for add_n_fft in extra_n_fft:
            n_fft = n_samples + add_n_fft
            data_csd_fourier = compute_epochs_csd(epochs_sin, mode='fourier',
                                                  tmin=None, tmax=tmax, fmin=0,
                                                  fmax=np.inf, n_fft=n_fft)
            fourier_power_per_sample = np.abs(data_csd_fourier.data[0, 0]) *\
                sfreq / data_csd_fourier.n_fft
            assert_true(abs(signal_power_per_sample -
                            fourier_power_per_sample) < 0.003)

        # Power per sample should not depend on number of tapers
        for n_tapers in [1, 2, 3, 5]:
            # Bandwidth grows with n_tapers; the exact number of tapers used
            # is decided inside compute_epochs_csd
            mt_bandwidth = sfreq / float(n_samples) * (n_tapers + 1)
            for add_n_fft in extra_n_fft:
                # Recompute n_fft for this loop; reusing the value left over
                # from the Fourier loop would never vary the FFT length here
                n_fft = n_samples + add_n_fft
                data_csd_mt = compute_epochs_csd(epochs_sin, mode='multitaper',
                                                 tmin=None, tmax=tmax, fmin=0,
                                                 fmax=np.inf,
                                                 mt_bandwidth=mt_bandwidth,
                                                 n_fft=n_fft)
                mt_power_per_sample = np.abs(data_csd_mt.data[0, 0]) *\
                    sfreq / data_csd_mt.n_fft
                # The estimate of power gets worse for small time windows when
                # more tapers are used
                if n_tapers == 5 and tmax == 0.2:
                    delta = 0.05
                else:
                    delta = 0.004
                assert_true(abs(signal_power_per_sample - mt_power_per_sample)
                            < delta)