1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106
|
# Authors: Eric Larson <larson.eric.d@gmail.com>
#
# License: BSD (3-clause)
import numpy as np
from numpy.testing import assert_allclose, assert_equal, assert_array_equal
from scipy import linalg
from .. import pick_types, Evoked
from ..io import BaseRaw
from ..io.constants import FIFF
from ..bem import fit_sphere_to_headshape
def _get_data(x, ch_idx):
    """Get the (n_ch, n_times) data array.

    Parameters
    ----------
    x : instance of BaseRaw | Evoked
        The object to read data from.
    ch_idx : array-like of int
        Indices of the channels to extract.

    Returns
    -------
    data : ndarray
        The data of the selected channels.

    Raises
    ------
    TypeError
        If ``x`` is neither a BaseRaw nor an Evoked instance.
    """
    if isinstance(x, BaseRaw):
        # Raw indexing returns a tuple whose first element is the data.
        return x[ch_idx][0]
    elif isinstance(x, Evoked):
        return x.data[ch_idx]
    # Previously this fell through and returned None, which deferred the
    # failure to an obscure arithmetic error in the caller.
    raise TypeError('x must be an instance of BaseRaw or Evoked, got %s'
                    % (type(x),))
def _check_snr(actual, desired, picks, min_tol, med_tol, msg, kind='MEG'):
    """Check the SNR of a set of channels.

    Parameters
    ----------
    actual : instance of BaseRaw | Evoked
        The data under test.
    desired : instance of BaseRaw | Evoked
        The reference data.
    picks : array-like of int
        Indices of the channels to compare.
    min_tol : float
        Every picked channel must reach at least this SNR.
    med_tol : float
        The median SNR across the picked channels must reach this value.
    msg : str | None
        Extra context appended to assertion messages. ``None`` or ``''``
        adds nothing.
    kind : str
        Channel-type label used in the median-SNR failure message.
    """
    actual_data = _get_data(actual, picks)
    desired_data = _get_data(desired, picks)
    # Per-channel RMS of the reference signal and of the residual error
    bench_rms = np.sqrt(np.mean(desired_data * desired_data, axis=1))
    error = actual_data - desired_data
    error_rms = np.sqrt(np.mean(error * error, axis=1))
    np.clip(error_rms, 1e-60, np.inf, out=error_rms)  # avoid division by zero
    snrs = bench_rms / error_rms
    # Both None (the assert_meg_snr default) and '' mean "no extra context".
    # Previously None was rendered literally as " (None)".
    msg = ' (%s)' % msg if msg else ''
    # min tol: every channel must meet it
    snr = snrs.min()
    bad_count = (snrs < min_tol).sum()
    assert bad_count == 0, ('SNR (worst %0.2f) < %0.2f for %s/%s '
                            'channels%s' % (snr, min_tol, bad_count,
                                            len(picks), msg))
    # median tol: the typical channel must meet it
    snr = np.median(snrs)
    assert snr >= med_tol, ('%s SNR median %0.2f < %0.2f%s'
                            % (kind, snr, med_tol, msg))
def assert_meg_snr(actual, desired, min_tol, med_tol=500., chpi_med_tol=500.,
                   msg=None):
    """Assert channel SNR of a certain level.

    Mostly useful for operations like Maxwell filtering that modify
    MEG channels while leaving EEG and others intact.

    Parameters
    ----------
    actual : instance of BaseRaw | Evoked
        The processed data under test.
    desired : instance of BaseRaw | Evoked
        The reference data.
    min_tol : float
        Minimum per-channel SNR required on the MEG channels.
    med_tol : float
        Minimum median SNR required on the MEG channels.
    chpi_med_tol : float | None
        Minimum median SNR required on the cHPI channels. If None, the cHPI
        picks are not compared and no cHPI SNR check is done.
    msg : str | None
        Extra context for the assertion messages.
    """
    # Pick from each object separately so that a channel-layout mismatch is
    # actually detected. Both picks used to come from ``desired``, which made
    # this assertion compare an array with itself.
    picks = pick_types(actual.info, meg=True, exclude=[])
    picks_desired = pick_types(desired.info, meg=True, exclude=[])
    assert_array_equal(picks, picks_desired, err_msg='MEG pick mismatch')
    chpis = pick_types(actual.info, meg=False, chpi=True, exclude=[])
    chpis_desired = pick_types(desired.info, meg=False, chpi=True, exclude=[])
    if chpi_med_tol is not None:
        assert_array_equal(chpis, chpis_desired, err_msg='cHPI pick mismatch')
    # Channels that are neither MEG nor cHPI must match closely
    others = np.setdiff1d(np.arange(len(actual.ch_names)),
                          np.concatenate([picks, chpis]))
    others_desired = np.setdiff1d(np.arange(len(desired.ch_names)),
                                  np.concatenate([picks_desired,
                                                  chpis_desired]))
    assert_array_equal(others, others_desired, err_msg='Other pick mismatch')
    if len(others) > 0:  # if non-MEG channels present
        assert_allclose(_get_data(actual, others),
                        _get_data(desired, others), atol=1e-11, rtol=1e-5,
                        err_msg='non-MEG channel mismatch')
    _check_snr(actual, desired, picks, min_tol, med_tol, msg, kind='MEG')
    if chpi_med_tol is not None and len(chpis) > 0:
        _check_snr(actual, desired, chpis, 0., chpi_med_tol, msg, kind='cHPI')
def assert_snr(actual, desired, tol):
    """Assert actual and desired arrays are within some SNR tolerance.

    The SNR is the Frobenius norm of ``desired`` divided by the Frobenius
    norm of the residual ``desired - actual``.

    Parameters
    ----------
    actual : ndarray
        The computed array.
    desired : ndarray
        The reference array. ``ord='fro'`` requires 2-D input.
    tol : float
        Minimum acceptable SNR.
    """
    signal = linalg.norm(desired, ord='fro')
    noise = linalg.norm(desired - actual, ord='fro')
    # An exact match has perfect SNR. Dividing by a zero noise norm would
    # give nan (0/0, which fails the check) or a divide-by-zero warning/error.
    snr = np.inf if noise == 0 else signal / noise
    assert snr >= tol, '%f < %f' % (snr, tol)
def _dig_sort_key(dig):
"""Sort dig keys."""
return (dig['kind'], dig['ident'])
def assert_dig_allclose(info_py, info_bin, limit=None):
    """Assert dig allclose.

    Parameters
    ----------
    info_py, info_bin : dict-like (presumably instances of Info)
        The two measurement infos whose ``'dig'`` entries are compared.
    limit : int | None
        If given, only the first ``limit`` points (after sorting by kind
        and ident) are compared point by point. The total point counts must
        still match, and the head-sphere fit is still run on the full infos.
    """
    # test dig positions. A missing dig (None) is treated as "no points"
    # instead of crashing inside sorted().
    dig_py = sorted(info_py['dig'] or [], key=_dig_sort_key)
    dig_bin = sorted(info_bin['dig'] or [], key=_dig_sort_key)
    assert len(dig_py) == len(dig_bin), (
        'dig length mismatch: %s != %s' % (len(dig_py), len(dig_bin)))
    for ii, (d_py, d_bin) in enumerate(zip(dig_py[:limit], dig_bin[:limit])):
        for key in ('ident', 'kind', 'coord_frame'):
            assert_equal(d_py[key], d_bin[key])
        assert_allclose(d_py['r'], d_bin['r'], rtol=1e-5, atol=1e-5,
                        err_msg='Failure on %s:\n%s\n%s'
                        % (ii, d_py['r'], d_bin['r']))
    # With extra headshape points present, the fitted spheres must agree too
    if any(d['kind'] == FIFF.FIFFV_POINT_EXTRA for d in dig_py):
        r_bin, o_head_bin, o_dev_bin = fit_sphere_to_headshape(
            info_bin, units='m', verbose='error')
        r_py, o_head_py, o_dev_py = fit_sphere_to_headshape(
            info_py, units='m', verbose='error')
        assert_allclose(r_py, r_bin, atol=1e-6)
        assert_allclose(o_dev_py, o_dev_bin, rtol=1e-5, atol=1e-6)
        assert_allclose(o_head_py, o_head_bin, rtol=1e-5, atol=1e-6)
|