File: common.py

package info (click to toggle)
python-mne 0.17%2Bdfsg-1
  • links: PTS, VCS
  • area: main
  • in suites: buster
  • size: 95,104 kB
  • sloc: python: 110,639; makefile: 222; sh: 15
file content (106 lines) | stat: -rw-r--r-- 4,529 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
# Authors: Eric Larson <larson.eric.d@gmail.com>
#
# License: BSD (3-clause)

import numpy as np
from numpy.testing import assert_allclose, assert_equal, assert_array_equal

from scipy import linalg

from .. import pick_types, Evoked
from ..io import BaseRaw
from ..io.constants import FIFF
from ..bem import fit_sphere_to_headshape


def _get_data(x, ch_idx):
    """Get the (n_ch, n_times) data array."""
    if isinstance(x, BaseRaw):
        return x[ch_idx][0]
    elif isinstance(x, Evoked):
        return x.data[ch_idx]


def _check_snr(actual, desired, picks, min_tol, med_tol, msg, kind='MEG'):
    """Check the SNR of a set of channels."""
    actual_data = _get_data(actual, picks)
    desired_data = _get_data(desired, picks)
    bench_rms = np.sqrt(np.mean(desired_data * desired_data, axis=1))
    error = actual_data - desired_data
    error_rms = np.sqrt(np.mean(error * error, axis=1))
    np.clip(error_rms, 1e-60, np.inf, out=error_rms)  # avoid division by zero
    snrs = bench_rms / error_rms
    # min tol
    snr = snrs.min()
    bad_count = (snrs < min_tol).sum()
    msg = ' (%s)' % msg if msg != '' else msg
    assert bad_count == 0, ('SNR (worst %0.2f) < %0.2f for %s/%s '
                            'channels%s' % (snr, min_tol, bad_count,
                                            len(picks), msg))
    # median tol
    snr = np.median(snrs)
    assert snr >= med_tol, ('%s SNR median %0.2f < %0.2f%s'
                            % (kind, snr, med_tol, msg))


def assert_meg_snr(actual, desired, min_tol, med_tol=500., chpi_med_tol=500.,
                   msg=None):
    """Assert channel SNR of a certain level.

    Mostly useful for operations like Maxwell filtering that modify
    MEG channels while leaving EEG and others intact.
    """
    picks = pick_types(desired.info, meg=True, exclude=[])
    picks_desired = pick_types(desired.info, meg=True, exclude=[])
    assert_array_equal(picks, picks_desired, err_msg='MEG pick mismatch')
    chpis = pick_types(actual.info, meg=False, chpi=True, exclude=[])
    chpis_desired = pick_types(desired.info, meg=False, chpi=True, exclude=[])
    if chpi_med_tol is not None:
        assert_array_equal(chpis, chpis_desired, err_msg='cHPI pick mismatch')
    others = np.setdiff1d(np.arange(len(actual.ch_names)),
                          np.concatenate([picks, chpis]))
    others_desired = np.setdiff1d(np.arange(len(desired.ch_names)),
                                  np.concatenate([picks_desired,
                                                  chpis_desired]))
    assert_array_equal(others, others_desired, err_msg='Other pick mismatch')
    if len(others) > 0:  # if non-MEG channels present
        assert_allclose(_get_data(actual, others),
                        _get_data(desired, others), atol=1e-11, rtol=1e-5,
                        err_msg='non-MEG channel mismatch')
    _check_snr(actual, desired, picks, min_tol, med_tol, msg, kind='MEG')
    if chpi_med_tol is not None and len(chpis) > 0:
        _check_snr(actual, desired, chpis, 0., chpi_med_tol, msg, kind='cHPI')


def assert_snr(actual, desired, tol):
    """Assert actual and desired arrays are within some SNR tolerance."""
    snr = (linalg.norm(desired, ord='fro') /
           linalg.norm(desired - actual, ord='fro'))
    assert snr >= tol, '%f < %f' % (snr, tol)


def _dig_sort_key(dig):
    """Sort dig keys."""
    return (dig['kind'], dig['ident'])


def assert_dig_allclose(info_py, info_bin, limit=None):
    """Assert dig allclose."""
    # test dig positions
    dig_py = sorted(info_py['dig'], key=_dig_sort_key)
    dig_bin = sorted(info_bin['dig'], key=_dig_sort_key)
    assert len(dig_py) == len(dig_bin)
    for ii, (d_py, d_bin) in enumerate(zip(dig_py[:limit], dig_bin[:limit])):
        for key in ('ident', 'kind', 'coord_frame'):
            assert_equal(d_py[key], d_bin[key])
        assert_allclose(d_py['r'], d_bin['r'], rtol=1e-5, atol=1e-5,
                        err_msg='Failure on %s:\n%s\n%s'
                        % (ii, d_py['r'], d_bin['r']))
    if any(d['kind'] == FIFF.FIFFV_POINT_EXTRA for d in dig_py):
        r_bin, o_head_bin, o_dev_bin = fit_sphere_to_headshape(
            info_bin, units='m', verbose='error')
        r_py, o_head_py, o_dev_py = fit_sphere_to_headshape(
            info_py, units='m', verbose='error')
        assert_allclose(r_py, r_bin, atol=1e-6)
        assert_allclose(o_dev_py, o_dev_bin, rtol=1e-5, atol=1e-6)
        assert_allclose(o_head_py, o_head_bin, rtol=1e-5, atol=1e-6)