# -*- coding: utf-8 -*-
"""Some utility functions."""
from __future__ import print_function

# Authors: Alexandre Gramfort <alexandre.gramfort@telecom-paristech.fr>
#
# License: BSD (3-clause)

import atexit
try:  # Python 3.3+
    from collections.abc import Iterable
except ImportError:  # Python 2 compatibility
    from collections import Iterable
from contextlib import contextmanager
from distutils.version import LooseVersion
from functools import wraps
from functools import partial
import hashlib
import inspect
import json
import logging
import fnmatch

from math import log, ceil
import multiprocessing
import operator
import os
import os.path as op
import platform
import shutil
from shutil import rmtree
from string import Formatter
import subprocess
import sys
import tempfile
import time
import traceback
from unittest import SkipTest
import warnings
import webbrowser
import re

import numpy as np
from scipy import linalg, sparse

from .externals.six.moves import urllib
from .externals.six import string_types, StringIO, BytesIO, integer_types
from .externals.decorator import decorator

from .fixes import _get_args

logger = logging.getLogger('mne')  # single logger used across mne-python
logger.propagate = False  # don't propagate (in case of multiple imports)


def _memory_usage(*args, **kwargs):
    if isinstance(args[0], tuple):
        args[0][0](*args[0][1], **args[0][2])
    elif not isinstance(args[0], int):  # can be -1 for current use
        args[0]()
    return [-1]


try:
    from memory_profiler import memory_usage
except ImportError:
    memory_usage = _memory_usage


def nottest(f):
    """Mark a function as not a test (decorator)."""
    f.__test__ = False
    return f


# # # WARNING # # #
# This list must also be updated in doc/_templates/class.rst if it is
# changed here!
_doc_special_members = ('__contains__', '__getitem__', '__iter__', '__len__',
                        '__add__', '__sub__', '__mul__', '__div__',
                        '__neg__', '__hash__')

###############################################################################
# RANDOM UTILITIES


def _get_argvalues():
    """Return all arguments (except self) and values of read_raw_xxx."""
    # call stack
    # read_raw_xxx -> EOF -> verbose() -> BaseRaw.__init__ -> get_argvalues
    frame = inspect.stack()[4][0]
    fname = frame.f_code.co_filename
    if not fnmatch.fnmatch(fname, '*/mne/io/*'):
        return None
    args, _, _, values = inspect.getargvalues(frame)
    params = dict()
    for arg in args:
        params[arg] = values[arg]
    params.pop('self', None)
    return params


def _ensure_int(x, name='unknown', must_be='an int'):
    """Ensure a variable is an integer."""
    # This is preferred over numbers.Integral, see:
    # https://github.com/scipy/scipy/pull/7351#issuecomment-299713159
    try:
        x = int(operator.index(x))
    except TypeError:
        raise TypeError('%s must be %s, got %s' % (name, must_be, type(x)))
    return x


def _pl(x, non_pl=''):
    """Determine if plural should be used."""
    len_x = x if isinstance(x, (integer_types, np.generic)) else len(x)
    return non_pl if len_x == 1 else 's'


def _explain_exception(start=-1, stop=None, prefix='> '):
    """Explain an exception."""
    # start=-1 means "only the most recent caller"
    etype, value, tb = sys.exc_info()
    string = traceback.format_list(traceback.extract_tb(tb)[start:stop])
    string = (''.join(string).split('\n') +
              traceback.format_exception_only(etype, value))
    string = ':\n' + prefix + ('\n' + prefix).join(string)
    return string


def _get_call_line(in_verbose=False):
    """Get the call line from within a function."""
    # XXX Eventually we could auto-triage whether in a `verbose` decorated
    # function or not.
    # NB This probably only works for functions that are undecorated,
    # or decorated by `verbose`.
    back = 2 if not in_verbose else 4
    call_frame = inspect.getouterframes(inspect.currentframe())[back][0]
    context = inspect.getframeinfo(call_frame).code_context
    context = 'unknown' if context is None else context[0].strip()
    return context


def _sort_keys(x):
    """Sort and return keys of dict."""
    keys = list(x.keys())  # note: not thread-safe
    idx = np.argsort([str(k) for k in keys])
    keys = [keys[ii] for ii in idx]
    return keys


def object_hash(x, h=None):
    """Hash a reasonable python object.

    Parameters
    ----------
    x : object
        Object to hash. Can be anything comprised of nested versions of:
        {dict, list, tuple, ndarray, str, bytes, float, int, None}.
    h : hashlib HASH object | None
        Optional, object to add the hash to. None creates an MD5 hash.

    Returns
    -------
    digest : int
        The digest resulting from the hash.
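
    Examples
    --------
    A quick sketch; equal nested structures hash equally, regardless of
    dict key order:

    >>> h1 = object_hash({'a': [1, 2.], 'b': None})
    >>> h1 == object_hash({'b': None, 'a': [1, 2.]})
    True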
    """
    if h is None:
        h = hashlib.md5()
    if hasattr(x, 'keys'):
        # dict-like types
        keys = _sort_keys(x)
        for key in keys:
            object_hash(key, h)
            object_hash(x[key], h)
    elif isinstance(x, bytes):
        # must come before "str" below
        h.update(x)
    elif isinstance(x, (string_types, float, int, type(None))):
        h.update(str(type(x)).encode('utf-8'))
        h.update(str(x).encode('utf-8'))
    elif isinstance(x, (np.ndarray, np.number, np.bool_)):
        x = np.asarray(x)
        h.update(str(x.shape).encode('utf-8'))
        h.update(str(x.dtype).encode('utf-8'))
        h.update(x.tostring())
    elif hasattr(x, '__len__'):
        # all other list-like types
        h.update(str(type(x)).encode('utf-8'))
        for xx in x:
            object_hash(xx, h)
    else:
        raise RuntimeError('unsupported type: %s (%s)' % (type(x), x))
    return int(h.hexdigest(), 16)


def object_size(x):
    """Estimate the size of a reasonable python object.

    Parameters
    ----------
    x : object
        Object to approximate the size of.
        Can be anything comprised of nested versions of:
        {dict, list, tuple, ndarray, str, bytes, float, int, None}.

    Returns
    -------
    size : int
        The estimated size in bytes of the object.
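
    Examples
    --------
    A rough sketch (exact byte counts vary across Python and NumPy
    versions):

    >>> object_size([1, 2, 3]) > 0
    True
    >>> object_size(np.zeros(10)) >= 80  # at least 10 * 8 data bytes
    True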
    """
    # Note: this will not process object arrays properly (since those only
    # hold references)
    if isinstance(x, (bytes, string_types, int, float, type(None))):
        size = sys.getsizeof(x)
    elif isinstance(x, np.ndarray):
        # On newer versions of NumPy, just doing sys.getsizeof(x) works,
        # but on older ones you always get something small :(
        size = sys.getsizeof(np.array([])) + x.nbytes
    elif isinstance(x, np.generic):
        size = x.nbytes
    elif isinstance(x, dict):
        size = sys.getsizeof(x)
        for key, value in x.items():
            size += object_size(key)
            size += object_size(value)
    elif isinstance(x, (list, tuple)):
        size = sys.getsizeof(x) + sum(object_size(xx) for xx in x)
    elif sparse.isspmatrix_csc(x) or sparse.isspmatrix_csr(x):
        size = sum(sys.getsizeof(xx)
                   for xx in [x, x.data, x.indices, x.indptr])
    else:
        raise RuntimeError('unsupported type: %s (%s)' % (type(x), x))
    return size


def object_diff(a, b, pre=''):
    """Compute all differences between two python variables.

    Parameters
    ----------
    a : object
        Currently supported: dict, list, tuple, ndarray, int, str, bytes,
        float, StringIO, BytesIO.
    b : object
        Must be the same type as ``a``.
    pre : str
        String to prepend to each line.

    Returns
    -------
    diffs : str
        A string representation of the differences.
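
    Examples
    --------
    A small sketch of the output format (an empty string means no
    differences were found):

    >>> object_diff({'a': 1}, {'a': 1})
    ''
    >>> object_diff({'a': 1}, {'a': 2}).strip()
    "['a'] value mismatch (1, 2)"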
    """
    out = ''
    if type(a) != type(b):
        out += pre + ' type mismatch (%s, %s)\n' % (type(a), type(b))
    elif isinstance(a, dict):
        k1s = _sort_keys(a)
        k2s = _sort_keys(b)
        m1 = set(k2s) - set(k1s)
        if len(m1):
            out += pre + ' left missing keys %s\n' % (m1)
        for key in k1s:
            if key not in k2s:
                out += pre + ' right missing key %s\n' % key
            else:
                out += object_diff(a[key], b[key], pre + '[%s]' % repr(key))
    elif isinstance(a, (list, tuple)):
        if len(a) != len(b):
            out += pre + ' length mismatch (%s, %s)\n' % (len(a), len(b))
        else:
            for ii, (xx1, xx2) in enumerate(zip(a, b)):
                out += object_diff(xx1, xx2, pre + '[%s]' % ii)
    elif isinstance(a, (string_types, int, float, bytes)):
        if a != b:
            out += pre + ' value mismatch (%s, %s)\n' % (a, b)
    elif a is None:
        if b is not None:
            out += pre + ' left is None, right is not (%s)\n' % (b)
    elif isinstance(a, np.ndarray):
        if not np.array_equal(a, b):
            out += pre + ' array mismatch\n'
    elif isinstance(a, (StringIO, BytesIO)):
        if a.getvalue() != b.getvalue():
            out += pre + ' StringIO mismatch\n'
    elif sparse.isspmatrix(a):
        # sparsity and sparse type of b vs a already checked above by type()
        if b.shape != a.shape:
            out += pre + (' sparse matrix a and b shape mismatch '
                          '(%s vs %s)' % (a.shape, b.shape))
        else:
            c = a - b
            c.eliminate_zeros()
            if c.nnz > 0:
                out += pre + (' sparse matrix a and b differ on %s '
                              'elements' % c.nnz)
    elif hasattr(a, '__getstate__'):
        out += object_diff(a.__getstate__(), b.__getstate__(), pre)
    else:
        raise RuntimeError(pre + ': unsupported type %s (%s)' % (type(a), a))
    return out


def check_random_state(seed):
    """Turn seed into a np.random.RandomState instance.

    If seed is None, return the RandomState singleton used by np.random.
    If seed is an int, return a new RandomState instance seeded with seed.
    If seed is already a RandomState instance, return it.
    Otherwise raise ValueError.
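
    Examples
    --------
    A quick sketch of the accepted inputs:

    >>> rng = check_random_state(42)    # an int seeds a new RandomState
    >>> rng is check_random_state(rng)  # instances pass through unchanged
    True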
    """
    if seed is None or seed is np.random:
        return np.random.mtrand._rand
    if isinstance(seed, (int, np.integer)):
        return np.random.RandomState(seed)
    if isinstance(seed, np.random.RandomState):
        return seed
    raise ValueError('%r cannot be used to seed a numpy.random.RandomState'
                     ' instance' % seed)


def split_list(l, n, idx=False):
    """Split list in n (approx) equal pieces, possibly giving indices."""
    n = int(n)
    tot = len(l)
    sz = tot // n
    start = stop = 0
    for i in range(n - 1):
        stop += sz
        yield (np.arange(start, stop), l[start:stop]) if idx else l[start:stop]
        start += sz
    yield (np.arange(start, tot), l[start:]) if idx else l[start:]


def array_split_idx(ary, indices_or_sections, axis=0, n_per_split=1):
    """Do what numpy.array_split does, but add indices."""
    # this only works for indices_or_sections as int
    indices_or_sections = _ensure_int(indices_or_sections)
    ary_split = np.array_split(ary, indices_or_sections, axis=axis)
    idx_split = np.array_split(np.arange(ary.shape[axis]), indices_or_sections)
    idx_split = (np.arange(sp[0] * n_per_split, (sp[-1] + 1) * n_per_split)
                 for sp in idx_split)
    return zip(idx_split, ary_split)


def create_chunks(sequence, size):
    """Generate chunks from a sequence.

    Parameters
    ----------
    sequence : iterable
        Any object that supports ``len`` and slicing.
    size : int
        The chunk size to be returned.
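
    Examples
    --------
    A short sketch:

    >>> list(create_chunks([1, 2, 3, 4, 5], 2))
    [[1, 2], [3, 4], [5]]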
    """
    return (sequence[p:p + size] for p in range(0, len(sequence), size))


def sum_squared(X):
    """Compute norm of an array.

    Parameters
    ----------
    X : array
        Data whose norm must be found

    Returns
    -------
    value : float
        Sum of squares of the input array X
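
    Examples
    --------
    A minimal sketch:

    >>> float(sum_squared(np.array([1., 2., 2.])))
    9.0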
    """
    X_flat = X.ravel(order='F' if np.isfortran(X) else 'C')
    return np.dot(X_flat, X_flat)


def warn(message, category=RuntimeWarning, module='mne'):
    """Emit a warning with trace outside the mne namespace.

    This function takes arguments like warnings.warn, and sends messages
    using both ``warnings.warn`` and ``logger.warning``. Warnings can be
    generated deep within nested function calls. In order to provide a
    more helpful warning, this function traverses the stack until it
    reaches a frame outside the ``mne`` namespace that caused the error.

    Parameters
    ----------
    message : str
        Warning message.
    category : instance of Warning
        The warning class. Defaults to ``RuntimeWarning``.
    module : str
        The name of the module emitting the warning.
    """
    root_dir = op.dirname(__file__)
    frame = None
    if logger.level <= logging.WARN:
        last_fname = ''
        frame = inspect.currentframe()
        while frame:
            fname = frame.f_code.co_filename
            lineno = frame.f_lineno
            # in verbose dec
            if fname == '<string>' and last_fname == 'utils.py':
                last_fname = fname
                frame = frame.f_back
                continue
            # treat tests as scripts
            # and don't capture unittest/case.py (assert_raises)
            if not (fname.startswith(root_dir) or
                    ('unittest' in fname and 'case' in fname)) or \
                    op.basename(op.dirname(fname)) == 'tests':
                break
            last_fname = op.basename(fname)
            frame = frame.f_back
        del frame
        # We need to use this instead of warn(message, category, stacklevel)
        # because we move out of the MNE stack, so warnings won't properly
        # recognize the module name (and our warnings.simplefilter will fail)
        warnings.warn_explicit(
            message, category, fname, lineno, module,
            globals().get('__warningregistry__', {}))
    # To avoid a duplicate warning print, we only emit the logger.warning if
    # one of the handlers is a FileHandler. See gh-5592
    if any(isinstance(h, logging.FileHandler) or getattr(h, '_mne_file_like',
                                                         False)
           for h in logger.handlers):
        logger.warning(message)


def filter_out_warnings(warn_record, category=None, match=None):
    r"""Remove particular records from ``warn_record``.

    This helper takes a list of :class:`warnings.WarningMessage` objects,
    and removes those matching category and/or text.

    Parameters
    ----------
    warn_record : pytest.WarningsChecker
        Warning record, as returned by ``pytest.warns(None)``.
    category : type of Warning | None
        Class of the warning message to filter out.
    match : str | None
        Text or regex that matches the warning message to filter out.

    Examples
    --------
    This can be used as::

        >>> import pytest
        >>> import warnings
        >>> from mne.utils import filter_out_warnings
        >>> with pytest.warns(None) as recwarn:
        ...     warnings.warn("value must be 0 or None", UserWarning)
        >>> filter_out_warnings(recwarn, match=".* 0 or None")
        >>> assert len(recwarn.list) == 0

        >>> with pytest.warns(None) as recwarn:
        ...     warnings.warn("value must be 42", UserWarning)
        >>> filter_out_warnings(recwarn, match=r'.* must be \d+$')
        >>> assert len(recwarn.list) == 0

        >>> with pytest.warns(None) as recwarn:
        ...     warnings.warn("this is not here", UserWarning)
        >>> filter_out_warnings(recwarn, match=r'.* must be \d+$')
        >>> assert len(recwarn.list) == 1
    """
    regexp = re.compile('.*' if match is None else match)
    is_category = [w.category == category if category is not None else True
                   for w in warn_record._list]
    is_match = [regexp.match(w.message.args[0]) is not None
                for w in warn_record._list]
    ind = [ind for ind, (c, m) in enumerate(zip(is_category, is_match))
           if c and m]

    for i in reversed(ind):
        warn_record._list.pop(i)


def check_fname(fname, filetype, endings, endings_err=()):
    """Enforce MNE filename conventions.

    Parameters
    ----------
    fname : str
        Name of the file.
    filetype : str
        Type of file (e.g., ICA, Epochs, etc.).
    endings : tuple
        Acceptable endings for the filename.
    endings_err : tuple
        Endings that the filename must have; an IOError is raised otherwise.
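
    Examples
    --------
    A sketch with a hypothetical filename (warns if the convention is not
    followed):

    >>> check_fname('sample-ica.fif', 'ICA', ('-ica.fif',))  # doctest: +SKIP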
    """
    if len(endings_err) > 0 and not fname.endswith(endings_err):
        print_endings = ' or '.join([', '.join(endings_err[:-1]),
                                     endings_err[-1]])
        raise IOError('The filename (%s) for file type %s must end with %s'
                      % (fname, filetype, print_endings))
    print_endings = ' or '.join([', '.join(endings[:-1]), endings[-1]])
    if not fname.endswith(endings):
        warn('This filename (%s) does not conform to MNE naming conventions. '
             'All %s files should end with %s'
             % (fname, filetype, print_endings))


class _Counter(object):
    count = 1

    def __call__(self, *args, **kargs):
        c = self.count
        self.count += 1
        return c


class WrapStdOut(object):
    """Dynamically wrap to sys.stdout.

    This makes packages that monkey-patch sys.stdout (e.g. doctest,
    sphinx-gallery) work properly.
    """

    def __getattr__(self, name):  # noqa: D105
        # Even more ridiculous than this class, this must be sys.stdout (not
        # just stdout) in order for this to work (tested on OSX and Linux)
        if hasattr(sys.stdout, name):
            return getattr(sys.stdout, name)
        else:
            raise AttributeError("'file' object has not attribute '%s'" % name)


class _TempDir(str):
    """Create and auto-destroy temp dir.

    This is designed to be used with testing modules. Instances should be
    defined inside test functions. Instances defined at module level can not
    guarantee proper destruction of the temporary directory.

    When used at module level, cleanup via the __del__() method can fail
    because the rmtree function may itself be cleaned up before this object
    (an alternative could be to use the atexit module instead).
    """

    def __new__(self):  # noqa: D105
        new = str.__new__(self, tempfile.mkdtemp(prefix='tmp_mne_tempdir_'))
        return new

    def __init__(self):  # noqa: D102
        self._path = self.__str__()

    def __del__(self):  # noqa: D105
        rmtree(self._path, ignore_errors=True)


def estimate_rank(data, tol='auto', return_singular=False, norm=True):
    """Estimate the rank of data.

    This function will normalize the rows of the data (typically
    channels or vertices) such that non-zero singular values
    should be close to one.

    Parameters
    ----------
    data : array
        Data to estimate the rank of (should be 2-dimensional).
    tol : float | 'auto'
        Tolerance for singular values to consider non-zero in
        calculating the rank. The singular values are calculated
        in this method such that independent data are expected to
        have singular value around one. Can be 'auto' to use the
        same thresholding as ``scipy.linalg.orth``.
    return_singular : bool
        If True, also return the singular values that were used
        to determine the rank.
    norm : bool
        If True, data will be scaled by their estimated row-wise norm.
        Else data are assumed to be scaled. Defaults to True.

    Returns
    -------
    rank : int
        Estimated rank of the data.
    s : array
        If return_singular is True, the singular values that were
        thresholded to determine the rank are also returned.
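
    Examples
    --------
    A small sketch with a rank-deficient matrix:

    >>> int(estimate_rank(np.diag([1., 1., 0.])))
    2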
    """
    data = data.copy()  # operate on a copy
    if norm is True:
        norms = _compute_row_norms(data)
        data /= norms[:, np.newaxis]
    s = linalg.svd(data, compute_uv=False, overwrite_a=True)
    rank = _estimate_rank_from_s(s, tol)
    if return_singular is True:
        return rank, s
    else:
        return rank


def _estimate_rank_from_s(s, tol='auto'):
    """Estimate the rank of a matrix from its singular values.

    Parameters
    ----------
    s : list of float
        The singular values of the matrix.
    tol : float | 'auto'
        Tolerance for singular values to consider non-zero in calculating the
        rank. Can be 'auto' to use the same thresholding as
        ``scipy.linalg.orth``.

    Returns
    -------
    rank : int
        The estimated rank.
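
    Examples
    --------
    A quick sketch (the default tolerance scales with ``len(s) * max(s)``):

    >>> int(_estimate_rank_from_s(np.array([1., 1e-3, 1e-17])))
    2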
    """
    if isinstance(tol, string_types):
        if tol != 'auto':
            raise ValueError('tol must be "auto" or float')
        eps = np.finfo(float).eps
        tol = len(s) * np.amax(s) * eps

    tol = float(tol)
    rank = np.sum(s > tol)
    return rank


def _compute_row_norms(data):
    """Compute scaling based on estimated norm."""
    norms = np.sqrt(np.sum(data ** 2, axis=1))
    norms[norms == 0] = 1.0
    return norms


def _reg_pinv(x, reg=0, rank='full', rcond=1e-15):
    """Compute a regularized pseudoinverse of a square matrix.

    Regularization is performed by adding a constant value to each diagonal
    element of the matrix before inversion. This is known as "diagonal
    loading". The loading factor is computed as ``reg * np.trace(x) / len(x)``.

    The pseudo-inverse is computed through SVD decomposition and inverting the
    singular values. When the matrix is rank deficient, some singular values
    will be close to zero and will not be used during the inversion. The number
    of singular values to use can either be manually specified or automatically
    estimated.

    Parameters
    ----------
    x : ndarray, shape (n, n)
        Square matrix to invert.
    reg : float
        Regularization parameter. Defaults to 0.
    rank : int | None | 'full'
        This controls the effective rank of the covariance matrix when
        computing the inverse. The rank can be set explicitly by specifying an
        integer value. If ``None``, the rank will be automatically estimated.
        Since applying regularization will always make the covariance matrix
        full rank, the rank is estimated before regularization in this case. If
        'full', the rank will be estimated after regularization and hence
        will mean using the full rank, unless ``reg=0`` is used.
        Defaults to 'full'.
    rcond : float | 'auto'
        Cutoff for detecting small singular values when attempting to estimate
        the rank of the matrix (``rank='auto'``). Singular values smaller than
        the cutoff are set to zero. When set to 'auto', a cutoff based on
        floating point precision will be used. Defaults to 1e-15.

    Returns
    -------
    x_inv : ndarray, shape (n, n)
        The inverted matrix.
    loading_factor : float
        Value added to the diagonal of the matrix during regularization.
    rank : int
        If ``rank`` was set to an integer value, this value is returned,
        else the estimated rank of the matrix, before regularization, is
        returned.
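
    Examples
    --------
    A minimal sketch on an identity matrix (trivially full rank, so no
    regularization is needed):

    >>> x_inv, loading, rank = _reg_pinv(np.eye(3))
    >>> bool(np.allclose(x_inv, np.eye(3)))
    True
    >>> int(rank)
    3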
    """
    if rank is not None and rank != 'full':
        rank = int(operator.index(rank))
    if x.ndim != 2 or x.shape[0] != x.shape[1]:
        raise ValueError('Input matrix must be square.')
    if not np.allclose(x, x.conj().T):
        raise ValueError('Input matrix must be Hermitian (symmetric)')

    # Decompose the matrix
    U, s, V = linalg.svd(x)

    # Estimate the rank before regularization
    tol = 'auto' if rcond == 'auto' else rcond * s.max()
    rank_before = _estimate_rank_from_s(s, tol)

    # Decompose the matrix again after regularization
    loading_factor = reg * np.mean(s)
    U, s, V = linalg.svd(x + loading_factor * np.eye(len(x)))

    # Estimate the rank after regularization
    tol = 'auto' if rcond == 'auto' else rcond * s.max()
    rank_after = _estimate_rank_from_s(s, tol)

    # Warn the user if both all parameters were kept at their defaults and the
    # matrix is rank deficient.
    if rank_after < len(x) and reg == 0 and rank == 'full' and rcond == 1e-15:
        warn('Covariance matrix is rank-deficient and no regularization is '
             'done.')
    elif isinstance(rank, int) and rank > len(x):
        raise ValueError('Invalid value for the rank parameter (%d) given '
                         'the shape of the input matrix (%d x %d).' %
                         (rank, x.shape[0], x.shape[1]))

    # Pick the requested number of singular values
    if rank is None:
        sel_s = s[:rank_before]
    elif rank == 'full':
        sel_s = s[:rank_after]
    else:
        sel_s = s[:rank]

    # Invert only non-zero singular values
    s_inv = np.zeros(s.shape)
    nonzero_inds = np.flatnonzero(sel_s != 0)
    if len(nonzero_inds) > 0:
        s_inv[nonzero_inds] = 1. / sel_s[nonzero_inds]

    # Compute the pseudo inverse
    x_inv = np.dot(V.T, s_inv[:, np.newaxis] * U.T)

    if rank is None or rank == 'full':
        return x_inv, loading_factor, rank_before
    else:
        return x_inv, loading_factor, rank


def _reject_data_segments(data, reject, flat, decim, info, tstep):
    """Reject data segments using peak-to-peak amplitude."""
    from .epochs import _is_good
    from .io.pick import channel_indices_by_type

    data_clean = np.empty_like(data)
    idx_by_type = channel_indices_by_type(info)
    step = int(ceil(tstep * info['sfreq']))
    if decim is not None:
        step = int(ceil(step / float(decim)))
    this_start = 0
    this_stop = 0
    drop_inds = []
    for first in range(0, data.shape[1], step):
        last = first + step
        data_buffer = data[:, first:last]
        if data_buffer.shape[1] < (last - first):
            break  # end of the time segment
        if _is_good(data_buffer, info['ch_names'], idx_by_type, reject,
                    flat, ignore_chs=info['bads']):
            this_stop = this_start + data_buffer.shape[1]
            data_clean[:, this_start:this_stop] = data_buffer
            this_start += data_buffer.shape[1]
        else:
            logger.info("Artifact detected in [%d, %d]" % (first, last))
            drop_inds.append((first, last))
    data = data_clean[:, :this_stop]
    if not data.any():
        raise RuntimeError('No clean segment found. Please '
                           'consider updating your rejection '
                           'thresholds.')
    return data, drop_inds


def _get_inst_data(inst):
    """Get data view from MNE object instance like Raw, Epochs or Evoked."""
    from .io.base import BaseRaw
    from .epochs import BaseEpochs
    from . import Evoked
    from .time_frequency.tfr import _BaseTFR

    _validate_type(inst, (BaseRaw, BaseEpochs, Evoked, _BaseTFR), "Instance")
    if not inst.preload:
        inst.load_data()
    return inst._data


class _FormatDict(dict):
    """Help pformat() work properly."""

    def __missing__(self, key):
        return "{" + key + "}"


def pformat(temp, **fmt):
    """Format a template string partially.

    Examples
    --------
    >>> pformat("{a}_{b}", a='x')
    'x_{b}'
    """
    formatter = Formatter()
    mapping = _FormatDict(fmt)
    return formatter.vformat(temp, (), mapping)


###############################################################################
# DECORATORS

# Following deprecated class copied from scikit-learn

# force show of DeprecationWarning even on python 2.7
warnings.filterwarnings('always', category=DeprecationWarning, module='mne')


class deprecated(object):
    """Mark a function or class as deprecated (decorator).

    Issue a warning when the function is called/the class is instantiated and
    adds a warning to the docstring.

    The optional extra argument will be appended to the deprecation message
    and the docstring. Note: to use this with the default value for extra, use
    an empty set of parentheses::

        >>> from mne.utils import deprecated
        >>> deprecated() # doctest: +ELLIPSIS
        <mne.utils.deprecated object at ...>

        >>> @deprecated()
        ... def some_function(): pass


    Parameters
    ----------
    extra : str
        To be added to the deprecation messages.
    """

    # Adapted from http://wiki.python.org/moin/PythonDecoratorLibrary,
    # but with many changes.

    # scikit-learn will not import on all platforms b/c it can be
    # sklearn or scikits.learn, so a self-contained example is used above

    def __init__(self, extra=''):  # noqa: D102
        self.extra = extra

    def __call__(self, obj):  # noqa: D105
        """Call.

        Parameters
        ----------
        obj : object
            Object to call.
        """
        if isinstance(obj, type):
            return self._decorate_class(obj)
        else:
            return self._decorate_fun(obj)

    def _decorate_class(self, cls):
        msg = "Class %s is deprecated" % cls.__name__
        if self.extra:
            msg += "; %s" % self.extra

        # FIXME: we should probably reset __new__ for full generality
        init = cls.__init__

        def deprecation_wrapped(*args, **kwargs):
            warnings.warn(msg, category=DeprecationWarning)
            return init(*args, **kwargs)
        cls.__init__ = deprecation_wrapped

        deprecation_wrapped.__name__ = '__init__'
        deprecation_wrapped.__doc__ = self._update_doc(init.__doc__)
        deprecation_wrapped.deprecated_original = init

        return cls

    def _decorate_fun(self, fun):
        """Decorate function fun."""
        msg = "Function %s is deprecated" % fun.__name__
        if self.extra:
            msg += "; %s" % self.extra

        def deprecation_wrapped(*args, **kwargs):
            warnings.warn(msg, category=DeprecationWarning)
            return fun(*args, **kwargs)

        deprecation_wrapped.__name__ = fun.__name__
        deprecation_wrapped.__dict__ = fun.__dict__
        deprecation_wrapped.__doc__ = self._update_doc(fun.__doc__)

        return deprecation_wrapped

    def _update_doc(self, olddoc):
        newdoc = ".. warning:: DEPRECATED"
        if self.extra:
            newdoc = "%s: %s" % (newdoc, self.extra)
        if olddoc:
            # Get the spacing right to avoid sphinx warnings
            n_space = 4
            for li, line in enumerate(olddoc.split('\n')):
                if li > 0 and len(line.strip()):
                    n_space = len(line) - len(line.lstrip())
                    break
            newdoc = "%s\n\n%s%s" % (newdoc, ' ' * n_space, olddoc)
        return newdoc


@decorator
def verbose(function, *args, **kwargs):
    """Verbose decorator to allow functions to override log-level.

    This decorator is used to set the verbose level during a function or method
    call, such as :func:`mne.compute_covariance`. The `verbose` keyword
    argument can be 'DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL', True (an
    alias for 'INFO'), or False (an alias for 'WARNING'). To set the global
    verbosity level for all functions, use :func:`mne.set_log_level`.

    Parameters
    ----------
    function : function
        Function to be decorated by setting the verbosity level.

    Returns
    -------
    dec : function
        The decorated function

    Examples
    --------
    You can use the ``verbose`` argument to set the verbose level on the fly::

        >>> import mne
        >>> cov = mne.compute_raw_covariance(raw, verbose='WARNING')  # doctest: +SKIP
        >>> cov = mne.compute_raw_covariance(raw, verbose='INFO')  # doctest: +SKIP
        Using up to 49 segments
        Number of samples used : 5880
        [done]

    See Also
    --------
    set_log_level
    set_config
    """  # noqa: E501
    arg_names = _get_args(function)
    default_level = verbose_level = None
    if len(arg_names) > 0 and arg_names[0] == 'self':
        default_level = getattr(args[0], 'verbose', None)
    if 'verbose' in arg_names:
        verbose_level = args[arg_names.index('verbose')]
    elif 'verbose' in kwargs:
        verbose_level = kwargs.pop('verbose')

    # This ensures that object.method(verbose=None) will use object.verbose
    verbose_level = default_level if verbose_level is None else verbose_level

    if verbose_level is not None:
        # set it back if we get an exception
        with use_log_level(verbose_level):
            return function(*args, **kwargs)
    return function(*args, **kwargs)


class use_log_level(object):
    """Context handler for logging level.

    Parameters
    ----------
    level : bool | str | int | None
        The level to use (as accepted by :func:`mne.set_log_level`).
    """

    def __init__(self, level):  # noqa: D102
        self.level = level

    def __enter__(self):  # noqa: D105
        self.old_level = set_log_level(self.level, True)

    def __exit__(self, *args):  # noqa: D105
        set_log_level(self.old_level)


def has_nibabel(vox2ras_tkr=False):
    """Determine if nibabel is installed.

    Parameters
    ----------
    vox2ras_tkr : bool
        If True, require nibabel has vox2ras_tkr support.

    Returns
    -------
    has : bool
        True if the user has nibabel.
    """
    try:
        import nibabel
        out = True
        if vox2ras_tkr:  # we need MGHHeader to have vox2ras_tkr param
            out = (getattr(getattr(getattr(nibabel, 'MGHImage', 0),
                                   'header_class', 0),
                           'get_vox2ras_tkr', None) is not None)
        return out
    except ImportError:
        return False


def has_mne_c():
    """Check for MNE-C."""
    return 'MNE_ROOT' in os.environ


def has_freesurfer():
    """Check for Freesurfer."""
    return 'FREESURFER_HOME' in os.environ


def requires_nibabel(vox2ras_tkr=False):
    """Check for nibabel."""
    import pytest
    extra = ' with vox2ras_tkr support' if vox2ras_tkr else ''
    return pytest.mark.skipif(not has_nibabel(vox2ras_tkr),
                              reason='Requires nibabel%s' % extra)


def requires_dipy():
    """Check for dipy."""
    import pytest
    # for some strange reason on CIs we can get a weird ImportError:
    #
    #     ImportError: dlopen: cannot load any more object with static TLS
    #
    # so let's import everything in the decorator.
    try:
        from dipy.align import imaffine, imwarp, metrics, transforms  # noqa, analysis:ignore
        from dipy.align.reslice import reslice  # noqa, analysis:ignore
        from dipy.align.imaffine import AffineMap  # noqa, analysis:ignore
        from dipy.align.imwarp import DiffeomorphicMap  # noqa, analysis:ignore
    except Exception:
        have = False
    else:
        have = True
    return pytest.mark.skipif(not have, reason='Requires dipy >= 0.10.1')


def buggy_mkl_svd(function):
    """Decorate tests that make calls to SVD and intermittently fail."""
    @wraps(function)
    def dec(*args, **kwargs):
        try:
            return function(*args, **kwargs)
        except np.linalg.LinAlgError as exp:
            if 'SVD did not converge' in str(exp):
                msg = 'Intel MKL SVD convergence error detected, skipping test'
                warn(msg)
                raise SkipTest(msg)
            raise
    return dec


def requires_version(library, min_version='0.0'):
    """Check for a library version."""
    import pytest
    return pytest.mark.skipif(not check_version(library, min_version),
                              reason=('Requires %s version >= %s'
                                      % (library, min_version)))


def requires_module(function, name, call=None):
    """Skip a test if package is not available (decorator)."""
    import pytest
    call = ('import %s' % name) if call is None else call
    reason = 'Test %s skipped, requires %s.' % (function.__name__, name)
    try:
        exec(call, globals(), locals())
    except Exception as exc:
        if len(str(exc)) > 0 and str(exc) != 'No module named %s' % name:
            reason += ' Got exception (%s)' % (exc,)
        skip = True
    else:
        skip = False
    return pytest.mark.skipif(skip, reason=reason)(function)


def copy_doc(source):
    """Copy the docstring from another function (decorator).

    The docstring of the source function is prepended to the docstring of the
    function wrapped by this decorator.

    This is useful when inheriting from a class and overloading a method. This
    decorator can be used to copy the docstring of the original method.

    Parameters
    ----------
    source : function
        Function to copy the docstring from

    Returns
    -------
    wrapper : function
        The decorated function

    Examples
    --------
    >>> class A:
    ...     def m1():
    ...         '''Docstring for m1'''
    ...         pass
    >>> class B (A):
    ...     @copy_doc(A.m1)
    ...     def m1():
    ...         ''' this gets appended'''
    ...         pass
    >>> print(B.m1.__doc__)
    Docstring for m1 this gets appended
    """
    def wrapper(func):
        if source.__doc__ is None or len(source.__doc__) == 0:
            raise ValueError('Cannot copy docstring: docstring was empty.')
        doc = source.__doc__
        if func.__doc__ is not None:
            doc += func.__doc__
        func.__doc__ = doc
        return func
    return wrapper


def copy_function_doc_to_method_doc(source):
    """Use the docstring from a function as docstring for a method.

    The docstring of the source function is prepended to the docstring of the
    function wrapped by this decorator. Additionally, the first parameter
    specified in the docstring of the source function is removed in the new
    docstring.

    This decorator is useful when implementing a method that just calls a
    function. This pattern is prevalent in, for example, the plotting
    functions of MNE.

    Parameters
    ----------
    source : function
        Function to copy the docstring from

    Returns
    -------
    wrapper : function
        The decorated method

    Examples
    --------
    >>> def plot_function(object, a, b):
    ...     '''Docstring for plotting function.
    ...
    ...     Parameters
    ...     ----------
    ...     object : instance of object
    ...         The object to plot
    ...     a : int
    ...         Some parameter
    ...     b : int
    ...         Some parameter
    ...     '''
    ...     pass
    ...
    >>> class A:
    ...     @copy_function_doc_to_method_doc(plot_function)
    ...     def plot(self, a, b):
    ...         '''
    ...         Notes
    ...         -----
    ...         .. versionadded:: 0.13.0
    ...         '''
    ...         plot_function(self, a, b)
    >>> print(A.plot.__doc__)
    Docstring for plotting function.
    <BLANKLINE>
        Parameters
        ----------
        a : int
            Some parameter
        b : int
            Some parameter
    <BLANKLINE>
            Notes
            -----
            .. versionadded:: 0.13.0
    <BLANKLINE>

    Notes
    -----
    The parsing performed is very basic and will break easily on docstrings
    that are not formatted exactly according to the ``numpydoc`` standard.
    Always inspect the resulting docstring when using this decorator.
    """
    def wrapper(func):
        doc = source.__doc__.split('\n')

        # Find parameter block
        for line, text in enumerate(doc[:-2]):
            if (text.strip() == 'Parameters' and
                    doc[line + 1].strip() == '----------'):
                parameter_block = line
                break
        else:
            # No parameter block found
            raise ValueError('Cannot copy function docstring: no parameter '
                             'block found. To simply copy the docstring, use '
                             'the @copy_doc decorator instead.')

        # Find first parameter
        for line, text in enumerate(doc[parameter_block:], parameter_block):
            if ':' in text:
                first_parameter = line
                parameter_indentation = len(text) - len(text.lstrip(' '))
                break
        else:
            raise ValueError('Cannot copy function docstring: no parameters '
                             'found. To simply copy the docstring, use the '
                             '@copy_doc decorator instead.')

        # Find end of first parameter
        for line, text in enumerate(doc[first_parameter + 1:],
                                    first_parameter + 1):
            # Ignore empty lines
            if len(text.strip()) == 0:
                continue

            line_indentation = len(text) - len(text.lstrip(' '))
            if line_indentation <= parameter_indentation:
                # Reached the end of the first parameter
                first_parameter_end = line

                # If only one parameter is defined, remove the Parameters
                # heading as well
                if ':' not in text:
                    first_parameter = parameter_block

                break
        else:
            # End of docstring reached
            first_parameter_end = line
            first_parameter = parameter_block

        # Copy the docstring, but remove the first parameter
        doc = ('\n'.join(doc[:first_parameter]) + '\n' +
               '\n'.join(doc[first_parameter_end:]))
        if func.__doc__ is not None:
            doc += func.__doc__
        func.__doc__ = doc
        return func
    return wrapper


_pandas_call = """
import pandas
version = LooseVersion(pandas.__version__)
if version < '0.8.0':
    raise ImportError
"""

_sklearn_call = """
required_version = '0.14'
import sklearn
version = LooseVersion(sklearn.__version__)
if version < required_version:
    raise ImportError
"""

_mayavi_call = """
with warnings.catch_warnings(record=True):  # traits
    from mayavi import mlab
mlab.options.backend = 'test'
"""

_mne_call = """
if not has_mne_c():
    raise ImportError
"""

_fs_call = """
if not has_freesurfer():
    raise ImportError
"""

_n2ft_call = """
if 'NEUROMAG2FT_ROOT' not in os.environ:
    raise ImportError
"""

_fs_or_ni_call = """
if not has_nibabel() and not has_freesurfer():
    raise ImportError
"""

requires_pandas = partial(requires_module, name='pandas', call=_pandas_call)
requires_sklearn = partial(requires_module, name='sklearn', call=_sklearn_call)
requires_mayavi = partial(requires_module, name='mayavi', call=_mayavi_call)
requires_mne = partial(requires_module, name='MNE-C', call=_mne_call)
requires_freesurfer = partial(requires_module, name='Freesurfer',
                              call=_fs_call)
requires_neuromag2ft = partial(requires_module, name='neuromag2ft',
                               call=_n2ft_call)
requires_fs_or_nibabel = partial(requires_module, name='nibabel or Freesurfer',
                                 call=_fs_or_ni_call)

requires_tvtk = partial(requires_module, name='TVTK',
                        call='from tvtk.api import tvtk')
requires_pysurfer = partial(requires_module, name='PySurfer',
                            call="""import warnings
with warnings.catch_warnings(record=True):
    from surfer import Brain""")
requires_good_network = partial(
    requires_module, name='good network connection',
    call='if int(os.environ.get("MNE_SKIP_NETWORK_TESTS", 0)):\n'
         '    raise ImportError')
requires_nitime = partial(requires_module, name='nitime')
requires_h5py = partial(requires_module, name='h5py')
requires_numpydoc = partial(requires_module, name='numpydoc')


def check_version(library, min_version):
    r"""Check minimum library version required.

    Parameters
    ----------
    library : str
        The library name to import. Must have a ``__version__`` property.
    min_version : str
        The minimum version string. Anything that matches
        ``'(\d+ | [a-z]+ | \.)'``. Can also be empty to skip version
        check (just check for library presence).

    Returns
    -------
    ok : bool
        True if the library exists with at least the specified version.
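
    Examples
    --------
    A quick sketch (assumes numpy is installed):

    >>> check_version('numpy', '1.0')
    True
    >>> check_version('numpy', '100.0')
    False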
    """
    ok = True
    try:
        library = __import__(library)
    except ImportError:
        ok = False
    else:
        if min_version:
            this_version = LooseVersion(library.__version__)
            if this_version < min_version:
                ok = False
    return ok


def _check_mayavi_version(min_version='4.3.0'):
    """Check mayavi version."""
    if not check_version('mayavi', min_version):
        raise RuntimeError("Need mayavi >= %s" % min_version)


def _check_pyface_backend():
    """Check the currently selected Pyface backend.

    Returns
    -------
    backend : str
        Name of the backend.
    result : 0 | 1 | 2
        0: the backend has been tested and works.
        1: the backend has not been tested.
        2: the backend could not be checked (traits is not installed).

    Notes
    -----
    See also http://docs.enthought.com/pyface/.
    """
    try:
        from traits.trait_base import ETSConfig
    except ImportError:
        return None, 2

    backend = ETSConfig.toolkit
    if backend == 'qt4':
        status = 0
    else:
        status = 1
    return backend, status


def _import_mlab():
    """Quietly import mlab."""
    with warnings.catch_warnings(record=True):
        from mayavi import mlab
    return mlab


@contextmanager
def traits_test_context():
    """Context to raise errors in trait handlers."""
    from traits.api import push_exception_handler

    push_exception_handler(reraise_exceptions=True)
    try:
        yield
    finally:
        push_exception_handler(reraise_exceptions=False)


def traits_test(test_func):
    """Raise errors in trait handlers (decorator)."""
    @wraps(test_func)
    def dec(*args, **kwargs):
        with traits_test_context():
            return test_func(*args, **kwargs)
    return dec


@verbose
def run_subprocess(command, verbose=None, *args, **kwargs):
    """Run command using subprocess.Popen.

    Run command and wait for command to complete. If the return code was zero
    then return, otherwise raise CalledProcessError.
    By default, this will also add stdout= and stderr=subprocess.PIPE
    to the call to Popen to suppress printing to the terminal.

    Parameters
    ----------
    command : list of str | str
        Command to run as subprocess (see subprocess.Popen documentation).
    verbose : bool, str, int, or None
        If not None, override default verbose level (see :func:`mne.verbose`
        and :ref:`Logging documentation <tut_logging>` for more). Defaults to
        self.verbose.
    *args, **kwargs : arguments
        Additional arguments to pass to subprocess.Popen.

    Returns
    -------
    stdout : str
        Stdout returned by the process.
    stderr : str
        Stderr returned by the process.
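
    Examples
    --------
    A sketch with a hypothetical command:

    >>> stdout, stderr = run_subprocess(['echo', 'hello'])  # doctest: +SKIP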
    """
    for stdxxx, sys_stdxxx, thresh in (
            ['stderr', sys.stderr, logging.ERROR],
            ['stdout', sys.stdout, logging.WARNING]):
        if stdxxx not in kwargs and logger.level >= thresh:
            kwargs[stdxxx] = subprocess.PIPE
        elif kwargs.get(stdxxx, sys_stdxxx) is sys_stdxxx:
            if isinstance(sys_stdxxx, StringIO):
                # nose monkey patches sys.stderr and sys.stdout to StringIO
                kwargs[stdxxx] = subprocess.PIPE
            else:
                kwargs[stdxxx] = sys_stdxxx

    # Check the PATH environment variable. If run_subprocess() is to be called
    # frequently this should be refactored so as to only check the path once.
    env = kwargs.get('env', os.environ)
    if any(p.startswith('~') for p in env['PATH'].split(os.pathsep)):
        warn('Your PATH environment variable contains at least one path '
             'starting with a tilde ("~") character. Such paths are not '
             'interpreted correctly from within Python. It is recommended '
             'that you use "$HOME" instead of "~".')
    if isinstance(command, string_types):
        command_str = command
    else:
        command_str = ' '.join(command)
    logger.info("Running subprocess: %s" % command_str)
    try:
        p = subprocess.Popen(command, *args, **kwargs)
    except Exception:
        if isinstance(command, string_types):
            command_name = command.split()[0]
        else:
            command_name = command[0]
        logger.error('Command not found: %s' % command_name)
        raise
    stdout_, stderr = p.communicate()
    stdout_ = u'' if stdout_ is None else stdout_.decode('utf-8')
    stderr = u'' if stderr is None else stderr.decode('utf-8')
    output = (stdout_, stderr)

    if p.returncode:
        print(output)
        err_fun = subprocess.CalledProcessError.__init__
        if 'output' in _get_args(err_fun):
            raise subprocess.CalledProcessError(p.returncode, command, output)
        else:
            raise subprocess.CalledProcessError(p.returncode, command)

    return output


###############################################################################
# LOGGING

def set_log_level(verbose=None, return_old_level=False):
    """Set the logging level.

    Parameters
    ----------
    verbose : bool, str, int, or None
        The verbosity of messages to print. If a str, it can be either DEBUG,
        INFO, WARNING, ERROR, or CRITICAL. Note that these are for
        convenience and are equivalent to passing in logging.DEBUG, etc.
        For bool, True is the same as 'INFO', False is the same as 'WARNING'.
        If None, the environment variable MNE_LOGGING_LEVEL is read, and if
        it doesn't exist, defaults to INFO.
    return_old_level : bool
        If True, return the old verbosity level.
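
    Examples
    --------
    A sketch of saving and restoring the level (modifies global state):

    >>> old = set_log_level('WARNING', return_old_level=True)  # doctest: +SKIP
    >>> set_log_level(old)  # doctest: +SKIP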
    """
    if verbose is None:
        verbose = get_config('MNE_LOGGING_LEVEL', 'INFO')
    elif isinstance(verbose, bool):
        if verbose is True:
            verbose = 'INFO'
        else:
            verbose = 'WARNING'
    if isinstance(verbose, string_types):
        verbose = verbose.upper()
        logging_types = dict(DEBUG=logging.DEBUG, INFO=logging.INFO,
                             WARNING=logging.WARNING, ERROR=logging.ERROR,
                             CRITICAL=logging.CRITICAL)
        if verbose not in logging_types:
            raise ValueError('verbose must be of a valid type')
        verbose = logging_types[verbose]
    logger = logging.getLogger('mne')
    old_verbose = logger.level
    logger.setLevel(verbose)
    return (old_verbose if return_old_level else None)


def set_log_file(fname=None, output_format='%(message)s', overwrite=None):
    """Set the log to print to a file.

    Parameters
    ----------
    fname : str | None
        Filename of the log to print to. If None, stdout is used.
        To suppress log outputs, use set_log_level('WARNING').
    output_format : str
        Format of the output messages. See the following for examples:

            https://docs.python.org/dev/howto/logging.html

        e.g., "%(asctime)s - %(levelname)s - %(message)s".
    overwrite : bool | None
        Overwrite the log file (if it exists). Otherwise, statements
        will be appended to the log (default). None is the same as False,
        but additionally raises a warning to notify the user that log
        entries will be appended.
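
    Examples
    --------
    A sketch with a hypothetical log file:

    >>> set_log_file('mne_output.log', overwrite=True)  # doctest: +SKIP
    >>> set_log_file(None)  # reset logging to stdout  # doctest: +SKIP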
    """
    logger = logging.getLogger('mne')
    handlers = logger.handlers
    for h in handlers:
        # only remove our handlers (get along nicely with nose)
        if isinstance(h, (logging.FileHandler, logging.StreamHandler)):
            if isinstance(h, logging.FileHandler):
                h.close()
            logger.removeHandler(h)
    if fname is not None:
        if op.isfile(fname) and overwrite is None:
            # Don't use warn() here because we just want to
            # emit a warnings.warn here (not logger.warning)
            warnings.warn('Log entries will be appended to the file. Use '
                          'overwrite=False to avoid this message in the '
                          'future.', RuntimeWarning, stacklevel=2)
            overwrite = False
        mode = 'w' if overwrite else 'a'
        lh = logging.FileHandler(fname, mode=mode)
    else:
        """ we should just be able to do:
                lh = logging.StreamHandler(sys.stdout)
            but because doctests uses some magic on stdout, we have to do this:
        """
        lh = logging.StreamHandler(WrapStdOut())

    lh.setFormatter(logging.Formatter(output_format))
    # actually add the stream handler
    logger.addHandler(lh)


class catch_logging(object):
    """Store logging.

    This will remove all other logging handlers, and return the handler to
    stdout when complete.
    """

    def __enter__(self):  # noqa: D105
        self._data = StringIO()
        self._lh = logging.StreamHandler(self._data)
        self._lh.setFormatter(logging.Formatter('%(message)s'))
        self._lh._mne_file_like = True  # monkey patch for warn() use
        for lh in logger.handlers:
            logger.removeHandler(lh)
        logger.addHandler(self._lh)
        return self._data

    def __exit__(self, *args):  # noqa: D105
        logger.removeHandler(self._lh)
        set_log_file(None)


###############################################################################
# CONFIG / PREFS

def get_subjects_dir(subjects_dir=None, raise_error=False):
    """Safely use subjects_dir input to return SUBJECTS_DIR.

    Parameters
    ----------
    subjects_dir : str | None
        If a value is provided, return subjects_dir. Otherwise, look for
        SUBJECTS_DIR config and return the result.
    raise_error : bool
        If True, raise a KeyError if no value for SUBJECTS_DIR can be found
        (instead of returning None).

    Returns
    -------
    value : str | None
        The SUBJECTS_DIR value.
    """
    if subjects_dir is None:
        subjects_dir = get_config('SUBJECTS_DIR', raise_error=raise_error)
    return subjects_dir


_temp_home_dir = None


def _get_extra_data_path(home_dir=None):
    """Get path to extra data (config, tables, etc.)."""
    global _temp_home_dir
    if home_dir is None:
        home_dir = os.environ.get('_MNE_FAKE_HOME_DIR')
    if home_dir is None:
        # this has been checked on OSX64, Linux64, and Win32
        if 'nt' == os.name.lower():
            if op.isdir(op.join(os.getenv('APPDATA'), '.mne')):
                home_dir = os.getenv('APPDATA')
            else:
                home_dir = os.getenv('USERPROFILE')
        else:
            # This is a more robust way of getting the user's home folder on
            # Linux platforms (not sure about OSX, Unix or BSD) than checking
            # the HOME environment variable. If the user is running some sort
            # of script that isn't launched via the command line (e.g. a script
            # launched via Upstart) then the HOME environment variable will
            # not be set.
            if os.getenv('MNE_DONTWRITE_HOME', '') == 'true':
                if _temp_home_dir is None:
                    _temp_home_dir = tempfile.mkdtemp()
                    atexit.register(partial(shutil.rmtree, _temp_home_dir,
                                            ignore_errors=True))
                home_dir = _temp_home_dir
            else:
                home_dir = os.path.expanduser('~')

        if home_dir is None:
            raise ValueError('mne-python config file path could '
                             'not be determined, please report this '
                             'error to mne-python developers')

    return op.join(home_dir, '.mne')


def get_config_path(home_dir=None):
    r"""Get path to standard mne-python config file.

    Parameters
    ----------
    home_dir : str | None
        The folder that contains the .mne config folder.
        If None, it is found automatically.

    Returns
    -------
    config_path : str
        The path to the mne-python configuration file. On Windows, this
        will be '%USERPROFILE%\.mne\mne-python.json'. On every other
        system, this will be ~/.mne/mne-python.json.
    """
    val = op.join(_get_extra_data_path(home_dir=home_dir),
                  'mne-python.json')
    return val


def set_cache_dir(cache_dir):
    """Set the directory to be used for temporary file storage.

    This directory is used by joblib to store memmapped arrays,
    which reduces memory requirements and speeds up parallel
    computation.

    Parameters
    ----------
    cache_dir : str | None
        Directory to use for temporary file storage. None disables
        temporary file storage.
    """
    if cache_dir is not None and not op.exists(cache_dir):
        raise IOError('Directory %s does not exist' % cache_dir)

    set_config('MNE_CACHE_DIR', cache_dir, set_env=False)


def set_memmap_min_size(memmap_min_size):
    """Set the minimum size for memmaping of arrays for parallel processing.

    Parameters
    ----------
    memmap_min_size: str or None
        Threshold on the minimum size of arrays that triggers automated memory
        mapping for parallel processing, e.g., '1M' for 1 megabyte.
        Use None to disable memmaping of large arrays.
    """
    if memmap_min_size is not None:
        if not isinstance(memmap_min_size, string_types):
            raise ValueError('\'memmap_min_size\' has to be a string.')
        if memmap_min_size[-1] not in ['K', 'M', 'G']:
            raise ValueError('The size has to be given in kilo-, mega-, or '
                             'gigabytes, e.g., 100K, 500M, 1G.')

    set_config('MNE_MEMMAP_MIN_SIZE', memmap_min_size, set_env=False)


# List the known configuration values
known_config_types = (
    'MNE_BROWSE_RAW_SIZE',
    'MNE_CACHE_DIR',
    'MNE_COREG_COPY_ANNOT',
    'MNE_COREG_GUESS_MRI_SUBJECT',
    'MNE_COREG_HEAD_HIGH_RES',
    'MNE_COREG_HEAD_OPACITY',
    'MNE_COREG_INTERACTION',
    'MNE_COREG_MARK_INSIDE',
    'MNE_COREG_PREPARE_BEM',
    'MNE_COREG_PROJECT_EEG',
    'MNE_COREG_ORIENT_TO_SURFACE',
    'MNE_COREG_SCALE_LABELS',
    'MNE_COREG_SCALE_BY_DISTANCE',
    'MNE_COREG_SCENE_SCALE',
    'MNE_COREG_WINDOW_HEIGHT',
    'MNE_COREG_WINDOW_WIDTH',
    'MNE_COREG_SUBJECTS_DIR',
    'MNE_CUDA_IGNORE_PRECISION',
    'MNE_DATA',
    'MNE_DATASETS_BRAINSTORM_PATH',
    'MNE_DATASETS_EEGBCI_PATH',
    'MNE_DATASETS_HF_SEF_PATH',
    'MNE_DATASETS_MEGSIM_PATH',
    'MNE_DATASETS_MISC_PATH',
    'MNE_DATASETS_MTRF_PATH',
    'MNE_DATASETS_SAMPLE_PATH',
    'MNE_DATASETS_SOMATO_PATH',
    'MNE_DATASETS_MULTIMODAL_PATH',
    'MNE_DATASETS_OPM_PATH',
    'MNE_DATASETS_SPM_FACE_DATASETS_TESTS',
    'MNE_DATASETS_SPM_FACE_PATH',
    'MNE_DATASETS_TESTING_PATH',
    'MNE_DATASETS_VISUAL_92_CATEGORIES_PATH',
    'MNE_DATASETS_KILOWORD_PATH',
    'MNE_DATASETS_FIELDTRIP_CMC_PATH',
    'MNE_DATASETS_PHANTOM_4DBTI_PATH',
    'MNE_FORCE_SERIAL',
    'MNE_KIT2FIFF_STIM_CHANNELS',
    'MNE_KIT2FIFF_STIM_CHANNEL_CODING',
    'MNE_KIT2FIFF_STIM_CHANNEL_SLOPE',
    'MNE_KIT2FIFF_STIM_CHANNEL_THRESHOLD',
    'MNE_LOGGING_LEVEL',
    'MNE_MEMMAP_MIN_SIZE',
    'MNE_SKIP_FTP_TESTS',
    'MNE_SKIP_NETWORK_TESTS',
    'MNE_SKIP_TESTING_DATASET_TESTS',
    'MNE_STIM_CHANNEL',
    'MNE_USE_CUDA',
    'MNE_SKIP_FS_FLASH_CALL',
    'SUBJECTS_DIR',
)

# These allow for partial matches: e.g. 'MNE_STIM_CHANNEL_1' is an okay key
known_config_wildcards = (
    'MNE_STIM_CHANNEL',
)


def _load_config(config_path, raise_error=False):
    """Safely load a config file."""
    with open(config_path, 'r') as fid:
        try:
            config = json.load(fid)
        except ValueError:
            # No JSON object could be decoded --> corrupt file?
            msg = ('The MNE-Python config file (%s) is not a valid JSON '
                   'file and might be corrupted' % config_path)
            if raise_error:
                raise RuntimeError(msg)
            warn(msg)
            config = dict()
    return config


def get_config(key=None, default=None, raise_error=False, home_dir=None):
    """Read MNE-Python preferences from environment or config file.

    Parameters
    ----------
    key : None | str
        The preference key to look for. The os environment is searched first,
        then the mne-python config file is parsed.
        If None, all the config parameters present in environment variables
        or the config file are returned.
    default : str | None
        Value to return if the key is not found.
    raise_error : bool
        If True, raise an error if the key is not found (instead of returning
        default).
    home_dir : str | None
        The folder that contains the .mne config folder.
        If None, it is found automatically.

    Returns
    -------
    value : dict | str | None
        The preference key value.

    See Also
    --------
    set_config
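
    Examples
    --------
    Read a key with a fallback for when it has not been set (the output
    shown is illustrative):

    >>> get_config('MNE_LOGGING_LEVEL', default='info')  # doctest: +SKIP
    'info'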
    """
    _validate_type(key, (string_types, type(None)), "key", 'string or None')

    # first, check to see if key is in env
    if key is not None and key in os.environ:
        return os.environ[key]

    # second, look for it in mne-python config file
    config_path = get_config_path(home_dir=home_dir)
    if not op.isfile(config_path):
        config = {}
    else:
        config = _load_config(config_path)

    if key is None:
        # update config with environment variables
        env_keys = (set(config).union(known_config_types).
                    intersection(os.environ))
        config.update({key: os.environ[key] for key in env_keys})
        return config
    elif raise_error is True and key not in config:
        meth_1 = 'os.environ["%s"] = VALUE' % key
        meth_2 = 'mne.utils.set_config("%s", VALUE, set_env=True)' % key
        raise KeyError('Key "%s" not found in environment or in the '
                       'mne-python config file: %s. '
                       'Try either %s for a temporary solution, or '
                       '%s for a permanent one. You can also '
                       'set the environment variable before '
                       'running python.'
                       % (key, config_path, meth_1, meth_2))
    else:
        return config.get(key, default)


def set_config(key, value, home_dir=None, set_env=True):
    """Set a MNE-Python preference key in the config file and environment.

    Parameters
    ----------
    key : str | None
        The preference key to set. If None, a tuple of the valid
        keys is returned, and ``value`` and ``home_dir`` are ignored.
    value : str | None
        The value to assign to the preference key. If None, the key is
        deleted.
    home_dir : str | None
        The folder that contains the .mne config folder.
        If None, it is found automatically.
    set_env : bool
        If True (default), update :data:`os.environ` in addition to
        updating the MNE-Python config file.

    See Also
    --------
    get_config
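
    Examples
    --------
    Set and later remove a key (the key shown is one of the known config
    types; passing None removes it):

    >>> set_config('MNE_LOGGING_LEVEL', 'info')  # doctest: +SKIP
    >>> set_config('MNE_LOGGING_LEVEL', None)  # doctest: +SKIP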
    """
    if key is None:
        return known_config_types
    _validate_type(key, 'str', "key")
    # While JSON allows non-string types, we allow users to override config
    # settings using environment variables, which are strings, so we enforce
    # that here
    _validate_type(value, (string_types, type(None)), "value",
                   "None or string")

    if key not in known_config_types and not \
            any(k in key for k in known_config_wildcards):
        warn('Setting non-standard config type: "%s"' % key)

    # Read all previous values
    config_path = get_config_path(home_dir=home_dir)
    if op.isfile(config_path):
        config = _load_config(config_path, raise_error=True)
    else:
        config = dict()
        logger.info('Attempting to create new mne-python configuration '
                    'file:\n%s' % config_path)
    if value is None:
        config.pop(key, None)
        if set_env and key in os.environ:
            del os.environ[key]
    else:
        config[key] = value
        if set_env:
            os.environ[key] = value

    # Write all values. This may fail if the default directory is not
    # writeable.
    directory = op.dirname(config_path)
    if not op.isdir(directory):
        os.mkdir(directory)
    with open(config_path, 'w') as fid:
        json.dump(config, fid, sort_keys=True, indent=0)


class ProgressBar(object):
    """Generate a command-line progressbar.

    Parameters
    ----------
    max_value : int | iterable
        Maximum value of process (e.g. number of samples to process, bytes to
        download, etc.). If an iterable is given, then `max_value` will be set
        to the length of this iterable.
    initial_value : int
        Initial value of process, useful when resuming process from a specific
        value, defaults to 0.
    mesg : str
        Message to include at end of progress bar.
    max_chars : int | str
        Number of characters to use for progress bar itself.
        This does not include characters used for the message or percent
        complete. Can be "auto" (default) to try to set a sane value based
        on the terminal width.
    progress_character : char
        Character in the progress bar that indicates the portion completed.
    spinner : bool
        Show a spinner.  Useful for long-running processes that may not
        increment the progress bar very often.  This provides the user with
        feedback that the progress has not stalled.
    max_total_width : int | str
        Maximum total message width. Can use "auto" (default) to try to set
        a sane value based on the current terminal width.
    verbose_bool : bool | 'auto'
        If True, show progress. If 'auto', show progress only when the
        logging level is INFO or lower.

    Examples
    --------
    >>> progress = ProgressBar(13000)
    >>> progress.update(3000) # doctest: +SKIP
    [.........                               ] 23.07692 |
    >>> progress.update(6000) # doctest: +SKIP
    [..................                      ] 46.15385 |

    >>> progress = ProgressBar(13000, spinner=True)
    >>> progress.update(3000) # doctest: +SKIP
    [.........                               ] 23.07692 |
    >>> progress.update(6000) # doctest: +SKIP
    [..................                      ] 46.15385 /
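
    Passing an iterable as ``max_value`` lets the bar auto-increment by one
    on each iteration:

    >>> for obj in ProgressBar(range(10)):  # doctest: +SKIP
    ...     pass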
    """

    spinner_symbols = ['|', '/', '-', '\\']
    template = '\r[{0}{1}] {2:6.02f}% {4} {3}   '

    def __init__(self, max_value, initial_value=0, mesg='', max_chars='auto',
                 progress_character='.', spinner=False,
                 max_total_width='auto', verbose_bool=True):  # noqa: D102
        self.cur_value = initial_value
        if isinstance(max_value, Iterable):
            self.max_value = len(max_value)
            self.iterable = max_value
        else:
            self.max_value = max_value
            self.iterable = None
        self.mesg = mesg
        self.progress_character = progress_character
        self.spinner = spinner
        self.spinner_index = 0
        self.n_spinner = len(self.spinner_symbols)
        if verbose_bool == 'auto':
            verbose_bool = logger.level <= logging.INFO
        self._do_print = verbose_bool
        self.cur_time = time.time()
        if max_total_width == 'auto':
            max_total_width = _get_terminal_width()
        self.max_total_width = int(max_total_width)
        if max_chars == 'auto':
            max_chars = min(max(max_total_width - 40, 10), 60)
        self.max_chars = int(max_chars)
        self.cur_rate = 0
        with tempfile.NamedTemporaryFile('wb', prefix='tmp_mne_prog') as tf:
            self._mmap_fname = tf.name
        del tf  # the file itself is removed when the context manager closes
        self._mmap = None

    def update(self, cur_value, mesg=None):
        """Update progressbar with current value of process.

        Parameters
        ----------
        cur_value : number
            Current value of process.  Should be <= max_value (but this is not
            enforced).  The percent of the progressbar will be computed as
            (cur_value / max_value) * 100
        mesg : str
            Message to display to the right of the progressbar.  If None, the
            last message provided will be used.  To clear the current message,
            pass an empty string, ''.
        """
        cur_time = time.time()
        cur_rate = ((cur_value - self.cur_value) /
                    max(float(cur_time - self.cur_time), 1e-6))
        # Smooth the estimate a bit
        cur_rate = 0.1 * cur_rate + 0.9 * self.cur_rate
        # Ensure floating-point division so we can get fractions of a percent
        # for the progressbar.
        self.cur_time = cur_time
        self.cur_value = cur_value
        self.cur_rate = cur_rate
        max_value = float(self.max_value) if self.max_value else 1.
        progress = np.clip(self.cur_value / max_value, 0, 1)
        num_chars = int(progress * self.max_chars)
        num_left = self.max_chars - num_chars

        # Update the message
        if mesg is not None:
            if mesg == 'file_sizes':
                mesg = '(%s, %s/s)' % (
                    sizeof_fmt(self.cur_value).rjust(8),
                    sizeof_fmt(cur_rate).rjust(8))
            self.mesg = mesg

        # The \r tells the cursor to return to the beginning of the line rather
        # than starting a new line.  This allows us to have a progressbar-style
        # display in the console window.
        bar = self.template.format(self.progress_character * num_chars,
                                   ' ' * num_left,
                                   progress * 100,
                                   self.spinner_symbols[self.spinner_index],
                                   self.mesg)
        bar = bar[:self.max_total_width]
        # Force a flush because sometimes when using bash scripts and pipes,
        # the output is not printed until after the program exits.
        if self._do_print:
            sys.stdout.write(bar)
            sys.stdout.flush()
        # Increment the spinner
        if self.spinner:
            self.spinner_index = (self.spinner_index + 1) % self.n_spinner

    def update_with_increment_value(self, increment_value, mesg=None):
        """Update progressbar with an increment.

        Parameters
        ----------
        increment_value : int
            Value of the increment of process.  The percent of the progressbar
            will be computed as
            ((self.cur_value + increment_value) / max_value) * 100
        mesg : str
            Message to display to the right of the progressbar.  If None, the
            last message provided will be used.  To clear the current message,
            pass an empty string, ''.
        """
        self.update(self.cur_value + increment_value, mesg)

    def __iter__(self):
        """Iterate to auto-increment the pbar with 1."""
        if self.iterable is None:
            raise ValueError("Must give an iterable to be used in a loop.")
        self.update(self.cur_value)
        for obj in self.iterable:
            yield obj
            self.update_with_increment_value(1)

    def __call__(self, seq):
        """Call the ProgressBar in a joblib-friendly way."""
        while True:
            try:
                yield next(seq)
            except StopIteration:
                return
            else:
                self.update_with_increment_value(1)

    def subset(self, idx):
        """Make a joblib-friendly index subset updater.

        Parameters
        ----------
        idx : ndarray
            List of indices for this subset.

        Returns
        -------
        updater : instance of PBSubsetUpdater
            Class with a ``.update(ii)`` method.
        """
        return _PBSubsetUpdater(self, idx)

    def __setitem__(self, idx, val):
        """Use alternative, mmap-based incrementing (max_value must be int)."""
        if not self._do_print:
            return
        assert val is True
        self._mmap[idx] = True
        self.update(self._mmap.sum())

    def __enter__(self):  # noqa: D105
        if op.isfile(self._mmap_fname):
            os.remove(self._mmap_fname)
        # prevent corner cases where self.max_value == 0
        self._mmap = np.memmap(self._mmap_fname, bool, 'w+',
                               shape=max(self.max_value, 1))
        self.update(0)  # must be zero as we just created the memmap
        return self

    def __exit__(self, type, value, traceback):  # noqa: D105
        """Clean up memmapped file."""
        # we can't put this in __del__ b/c then each worker will delete the
        # file, which is not so good
        self._mmap = None
        if op.isfile(self._mmap_fname):
            os.remove(self._mmap_fname)
        if self._do_print:
            print('')


class _PBSubsetUpdater(object):

    def __init__(self, pb, idx):
        self.pb = pb
        self.idx = idx

    def update(self, ii):
        self.pb[self.idx[:ii]] = True


def _get_terminal_width():
    """Get the terminal width."""
    if sys.version[0] == '2':
        return 80
    else:
        return shutil.get_terminal_size((80, 20)).columns


def _get_http(url, temp_file_name, initial_size, file_size, timeout,
              verbose_bool):
    """Safely (resume a) download to a file from http(s)."""
    # Actually do the reading
    req = urllib.request.Request(url)
    if initial_size > 0:
        req.headers['Range'] = 'bytes=%s-' % (initial_size,)
    try:
        response = urllib.request.urlopen(req, timeout=timeout)
    except Exception:
        # There is a problem that may be due to resuming, some
        # servers may not support the "Range" header. Switch
        # back to complete download method
        logger.info('Resuming download failed (server '
                    'rejected the request). Attempting to '
                    'restart downloading the entire file.')
        del req.headers['Range']
        response = urllib.request.urlopen(req, timeout=timeout)
    total_size = int(response.headers.get('Content-Length', '1').strip())
    if initial_size > 0 and file_size == total_size:
        logger.info('Resuming download failed (resume file size '
                    'mismatch). Attempting to restart downloading the '
                    'entire file.')
        initial_size = 0
    total_size += initial_size
    if total_size != file_size:
        raise RuntimeError('URL could not be parsed properly '
                           '(total size %s != file size %s)'
                           % (total_size, file_size))
    mode = 'ab' if initial_size > 0 else 'wb'
    progress = ProgressBar(total_size, initial_value=initial_size,
                           spinner=True, mesg='file_sizes',
                           verbose_bool=verbose_bool)
    chunk_size = 8192  # 2 ** 13, initial guess; adapted to throughput below
    with open(temp_file_name, mode) as local_file:
        while True:
            t0 = time.time()
            chunk = response.read(chunk_size)
            dt = time.time() - t0
            # Adapt the chunk size to the observed read time: grow it when
            # reads are fast, shrink it (but never below 8192 bytes) when
            # they are slow, so progress updates stay responsive.
            if dt < 0.005:
                chunk_size *= 2
            elif dt > 0.1 and chunk_size > 8192:
                chunk_size = chunk_size // 2
            if not chunk:
                if verbose_bool:
                    sys.stdout.write('\n')
                    sys.stdout.flush()
                break
            local_file.write(chunk)
            progress.update_with_increment_value(len(chunk),
                                                 mesg='file_sizes')


def _chunk_write(chunk, local_file, progress):
    """Write a chunk to file and update the progress bar."""
    local_file.write(chunk)
    progress.update_with_increment_value(len(chunk))


@verbose
def _fetch_file(url, file_name, print_destination=True, resume=True,
                hash_=None, timeout=30., verbose=None):
    """Load requested file, downloading it if needed or requested.

    Parameters
    ----------
    url : str
        The url of file to be downloaded.
    file_name : str
        Name, along with the path, of where downloaded file will be saved.
    print_destination : bool
        If True, destination of where file was saved will be printed after
        download finishes.
    resume : bool
        If True, try to resume partially downloaded files.
    hash_ : str | None
        The hash of the file to check. If None, no checking is
        performed.
    timeout : float
        The URL open timeout.
    verbose : bool, str, int, or None
        If not None, override default verbose level (see :func:`mne.verbose`
        and :ref:`Logging documentation <tut_logging>` for more).
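
    Examples
    --------
    A hedged sketch (the URL and file name are illustrative):

    >>> url = 'https://example.com/sample.fif'
    >>> _fetch_file(url, 'sample.fif', hash_=None)  # doctest: +SKIP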
    """
    # Adapted from NISL:
    # https://github.com/nisl/tutorial/blob/master/nisl/datasets.py
    if hash_ is not None and (not isinstance(hash_, string_types) or
                              len(hash_) != 32):
        raise ValueError('Bad hash value given, should be a 32-character '
                         'string:\n%s' % (hash_,))
    temp_file_name = file_name + ".part"
    verbose_bool = (logger.level <= 20)  # 20 is info
    try:
        # Check the file size and display it alongside the download URL
        # this loop is necessary to follow any redirects
        for _ in range(10):  # 10 really should be sufficient...
            u = urllib.request.urlopen(url, timeout=timeout)
            try:
                last_url, url = url, u.geturl()
                if url == last_url:
                    file_size = int(
                        u.headers.get('Content-Length', '1').strip())
                    break
            finally:
                u.close()
                del u
        else:
            raise RuntimeError('Too many redirects')
        logger.info('Downloading %s (%s)' % (url, sizeof_fmt(file_size)))

        # Triage resume
        if not os.path.exists(temp_file_name):
            resume = False
        if resume:
            with open(temp_file_name, 'rb', buffering=0) as local_file:
                local_file.seek(0, 2)
                initial_size = local_file.tell()
            del local_file
        else:
            initial_size = 0
        # This should never happen if our functions work properly
        if initial_size > file_size:
            raise RuntimeError('Local file (%s) is larger than remote '
                               'file (%s), cannot resume download'
                               % (sizeof_fmt(initial_size),
                                  sizeof_fmt(file_size)))
        elif initial_size == file_size:
            # This should really only happen when a hash is wrong
            # during dev updating
            warn('Local file appears to be complete (file_size == '
                 'initial_size == %s)' % (file_size,))
        else:
            # Need to resume or start over
            scheme = urllib.parse.urlparse(url).scheme
            if scheme not in ('http', 'https'):
                raise NotImplementedError('Cannot use %s' % (scheme,))
            _get_http(url, temp_file_name, initial_size, file_size, timeout,
                      verbose_bool)

        # check md5sum
        if hash_ is not None:
            logger.info('Verifying hash %s.' % (hash_,))
            md5 = md5sum(temp_file_name)
            if hash_ != md5:
                raise RuntimeError('Hash mismatch for downloaded file %s, '
                                   'expected %s but got %s'
                                   % (temp_file_name, hash_, md5))
        shutil.move(temp_file_name, file_name)
        if print_destination is True:
            logger.info('File saved as %s.\n' % file_name)
    except Exception:
        logger.error('Error while fetching file %s.'
                     ' Dataset fetching aborted.' % url)
        raise


def sizeof_fmt(num):
    """Turn number of bytes into human-readable str.

    Parameters
    ----------
    num : int
        The number of bytes.

    Returns
    -------
    size : str
        The size in human-readable format.
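
    Examples
    --------
    Some illustrative conversions:

    >>> sizeof_fmt(1024)  # doctest: +SKIP
    '1 kB'
    >>> sizeof_fmt(3 * 1024 ** 2)  # doctest: +SKIP
    '3.0 MB'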
    """
    units = ['bytes', 'kB', 'MB', 'GB', 'TB', 'PB']
    decimals = [0, 0, 1, 2, 2, 2]
    if num > 1:
        exponent = min(int(log(num, 1024)), len(units) - 1)
        quotient = float(num) / 1024 ** exponent
        unit = units[exponent]
        num_decimals = decimals[exponent]
        format_string = '{0:.%sf} {1}' % (num_decimals)
        return format_string.format(quotient, unit)
    if num == 0:
        return '0 bytes'
    if num == 1:
        return '1 byte'


class SizeMixin(object):
    """Estimate MNE object sizes."""

    @property
    def _size(self):
        """Estimate the object size."""
        try:
            size = object_size(self.info)
        except Exception:
            warn('Could not get size for self.info')
            return -1
        if hasattr(self, 'data'):
            size += object_size(self.data)
        elif hasattr(self, '_data'):
            size += object_size(self._data)
        return size

    def __hash__(self):
        """Hash the object.

        Returns
        -------
        hash : int
            The hash.
        """
        from .evoked import Evoked
        from .epochs import BaseEpochs
        from .io.base import BaseRaw
        if isinstance(self, Evoked):
            return object_hash(dict(info=self.info, data=self.data))
        elif isinstance(self, (BaseEpochs, BaseRaw)):
            _check_preload(self, "Hashing ")
            return object_hash(dict(info=self.info, data=self._data))
        else:
            raise RuntimeError('Hashing unknown object type: %s' % type(self))


def _url_to_local_path(url, path):
    """Mirror a url path in a local destination (keeping folder structure)."""
    destination = urllib.parse.urlparse(url).path
    # First char should be '/', and it needs to be discarded
    if len(destination) < 2 or destination[0] != '/':
        raise ValueError('Invalid URL')
    destination = os.path.join(path,
                               urllib.request.url2pathname(destination)[1:])
    return destination


def _get_stim_channel(stim_channel, info, raise_error=True):
    """Determine the appropriate stim_channel.

    First, 'MNE_STIM_CHANNEL', 'MNE_STIM_CHANNEL_1', 'MNE_STIM_CHANNEL_2', etc.
    are read. If these are not found, it will fall back to 'STI101' (newer
    systems) or 'STI 014' (older systems) if present, then fall back to the
    channels of type 'stim', if present.

    Parameters
    ----------
    stim_channel : str | list of str | None
        The stim channel selected by the user.
    info : instance of Info
        An information structure containing information about the channels.
    raise_error : bool
        If True (default), raise an error if no stim channel is found.

    Returns
    -------
    stim_channel : str | list of str
        The name of the stim channel(s) to use.
    """
    if stim_channel is not None:
        if not isinstance(stim_channel, list):
            _validate_type(stim_channel, 'str', "Stim channel")
            stim_channel = [stim_channel]
        for channel in stim_channel:
            _validate_type(channel, 'str', "Each provided stim channel")
        return stim_channel

    stim_channel = list()
    ch_count = 0
    ch = get_config('MNE_STIM_CHANNEL')
    while ch is not None and ch in info['ch_names']:
        stim_channel.append(ch)
        ch_count += 1
        ch = get_config('MNE_STIM_CHANNEL_%d' % ch_count)
    if ch_count > 0:
        return stim_channel

    if 'STI101' in info['ch_names']:  # combination channel for newer systems
        return ['STI101']
    if 'STI 014' in info['ch_names']:  # for older systems
        return ['STI 014']

    from .io.pick import pick_types
    stim_channel = pick_types(info, meg=False, ref_meg=False, stim=True)
    if len(stim_channel) > 0:
        stim_channel = [info['ch_names'][ch_] for ch_ in stim_channel]
    elif raise_error:
        raise ValueError("No stim channels found. Consider specifying them "
                         "manually using the 'stim_channel' parameter.")
    return stim_channel


def _check_fname(fname, overwrite=False, must_exist=False):
    """Check for file existence."""
    _validate_type(fname, 'str', 'fname')
    if must_exist and not op.isfile(fname):
        raise IOError('File "%s" does not exist' % fname)
    if op.isfile(fname):
        if not overwrite:
            raise IOError('Destination file exists. Please use option '
                          '"overwrite=True" to force overwriting.')
        elif overwrite != 'read':
            logger.info('Overwriting existing file.')


def _check_subject(class_subject, input_subject, raise_error=True):
    """Get subject name from class."""
    if input_subject is not None:
        _validate_type(input_subject, 'str', "subject input")
        return input_subject
    elif class_subject is not None:
        _validate_type(class_subject, 'str',
                       "Either subject input or class subject attribute")
        return class_subject
    else:
        if raise_error is True:
            raise ValueError('Neither subject input nor class subject '
                             'attribute was a string')
        return None


def _check_preload(inst, msg):
    """Ensure data are preloaded."""
    from .epochs import BaseEpochs
    from .evoked import Evoked
    from .time_frequency import _BaseTFR

    if isinstance(inst, (_BaseTFR, Evoked)):
        pass
    else:
        name = "epochs" if isinstance(inst, BaseEpochs) else 'raw'
        if not inst.preload:
            raise RuntimeError(
                "By default, MNE does not load data into main memory to "
                "conserve resources. " + msg + ' requires %s data to be '
                'loaded. Use preload=True (or string) in the constructor or '
                '%s.load_data().' % (name, name))


def _check_compensation_grade(inst, inst2, name, name2, ch_names=None):
    """Ensure that objects have same compensation_grade."""
    from .io.pick import pick_channels, pick_info
    from .io.compensator import get_current_comp

    if None in [inst.info, inst2.info]:
        return

    if ch_names is None:
        grade = inst.compensation_grade
        grade2 = inst2.compensation_grade
    else:
        info = inst.info.copy()
        info2 = inst2.info.copy()
        # pick channels
        for t_info in [info, info2]:
            if t_info['comps']:
                t_info['comps'] = []
            picks = pick_channels(t_info['ch_names'], ch_names)
            pick_info(t_info, picks, copy=False)
        # get compensation grades
        grade = get_current_comp(info)
        grade2 = get_current_comp(info2)

    # perform check
    if grade != grade2:
        msg = 'Compensation grades of %s (%d) and %s (%d) do not match'
        raise RuntimeError(msg % (name, grade, name2, grade2))


def _check_pandas_installed(strict=True):
    """Aux function."""
    try:
        import pandas
        return pandas
    except ImportError:
        if strict is True:
            raise RuntimeError('For this functionality to work, the Pandas '
                               'library is required.')
        else:
            return False


def _check_pandas_index_arguments(index, defaults):
    """Check pandas index arguments."""
    if not isinstance(index, (list, tuple)):
        index = [index]
    invalid_choices = [e for e in index if e not in defaults]
    if invalid_choices:
        options = [', '.join(e) for e in [invalid_choices, defaults]]
        raise ValueError('[%s] is not a valid option. Valid index '
                         'values are \'None\' or %s' % tuple(options))


def _check_ch_locs(chs):
    """Check if channel locations exist.

    Parameters
    ----------
    chs : dict
        The channels from info['chs']
    """
    locs3d = np.array([ch['loc'][:3] for ch in chs])
    return not ((locs3d == 0).all() or
                (~np.isfinite(locs3d)).all() or
                np.allclose(locs3d, 0.))


def _clean_names(names, remove_whitespace=False, before_dash=True):
    """Remove white-space on topo matching.

    This function handles different naming
    conventions for old VS new VectorView systems (`remove_whitespace`).
    Also it allows to remove system specific parts in CTF channel names
    (`before_dash`).

    Usage
    -----
    # for new VectorView (only inside layout)
    ch_names = _clean_names(epochs.ch_names, remove_whitespace=True)

    # for CTF
    ch_names = _clean_names(epochs.ch_names, before_dash=True)

    """
    cleaned = []
    for name in names:
        if ' ' in name and remove_whitespace:
            name = name.replace(' ', '')
        if '-' in name and before_dash:
            name = name.split('-')[0]
        if name.endswith('_v'):
            name = name[:-2]
        cleaned.append(name)

    return cleaned


def _check_type_picks(picks):
    """Guarantee type integrity of picks."""
    err_msg = 'picks must be None, a list or an array of integers'
    if picks is None:
        pass
    elif isinstance(picks, list):
        for pick in picks:
            _validate_type(pick, 'int', 'Each pick')
        picks = np.array(picks)
    elif isinstance(picks, np.ndarray):
        if not picks.dtype.kind == 'i':
            raise TypeError(err_msg)
    else:
        raise TypeError(err_msg)
    return picks


@nottest
def run_tests_if_main(measure_mem=False):
    """Run tests in a given file if it is run as a script."""
    local_vars = inspect.currentframe().f_back.f_locals
    if not local_vars.get('__name__', '') == '__main__':
        return
    # we are in a "__main__"
    try:
        import faulthandler
        faulthandler.enable()
    except Exception:
        pass
    with warnings.catch_warnings(record=True):  # memory_usage internal dep.
        mem = int(round(max(memory_usage(-1)))) if measure_mem else -1
    if mem >= 0:
        print('Memory consumption after import: %s' % mem)
    t0 = time.time()
    peak_mem, peak_name = mem, 'import'
    max_elapsed, elapsed_name = 0, 'N/A'
    count = 0
    for name in sorted(list(local_vars.keys()), key=lambda x: x.lower()):
        val = local_vars[name]
        if name.startswith('_'):
            continue
        elif callable(val) and name.startswith('test'):
            count += 1
            doc = val.__doc__.strip() if val.__doc__ else name
            sys.stdout.write('%s ... ' % doc)
            sys.stdout.flush()
            try:
                t1 = time.time()
                if measure_mem:
                    with warnings.catch_warnings(record=True):  # dep warn
                        mem = int(round(max(memory_usage((val, (), {})))))
                else:
                    val()
                    mem = -1
                if mem >= peak_mem:
                    peak_mem, peak_name = mem, name
                mem = (', mem: %s MB' % mem) if mem >= 0 else ''
                elapsed = int(round(time.time() - t1))
                if elapsed >= max_elapsed:
                    max_elapsed, elapsed_name = elapsed, name
                sys.stdout.write('time: %0.3f sec%s\n' % (elapsed, mem))
                sys.stdout.flush()
            except Exception as err:
                if 'skiptest' in err.__class__.__name__.lower():
                    sys.stdout.write('SKIP (%s)\n' % str(err))
                    sys.stdout.flush()
                else:
                    raise
    elapsed = int(round(time.time() - t0))
    sys.stdout.write('Total: %s tests\n• %0.3f sec (%0.3f sec for %s)\n• '
                     'Peak memory %s MB (%s)\n'
                     % (count, elapsed, max_elapsed, elapsed_name, peak_mem,
                        peak_name))


class ArgvSetter(object):
    """Temporarily set sys.argv."""

    def __init__(self, args=(), disable_stdout=True,
                 disable_stderr=True):  # noqa: D102
        self.argv = list(('python',) + args)
        self.stdout = StringIO() if disable_stdout else sys.stdout
        self.stderr = StringIO() if disable_stderr else sys.stderr

    def __enter__(self):  # noqa: D105
        self.orig_argv = sys.argv
        sys.argv = self.argv
        self.orig_stdout = sys.stdout
        sys.stdout = self.stdout
        self.orig_stderr = sys.stderr
        sys.stderr = self.stderr
        return self

    def __exit__(self, *args):  # noqa: D105
        sys.argv = self.orig_argv
        sys.stdout = self.orig_stdout
        sys.stderr = self.orig_stderr


class SilenceStdout(object):
    """Silence stdout."""

    def __enter__(self):  # noqa: D105
        self.stdout = sys.stdout
        sys.stdout = StringIO()
        return self

    def __exit__(self, *args):  # noqa: D105
        sys.stdout = self.stdout


def md5sum(fname, block_size=1048576):  # 2 ** 20
    """Calculate the md5sum for a file.

    Parameters
    ----------
    fname : str
        Filename.
    block_size : int
        Block size to use when reading.

    Returns
    -------
    hash_ : str
        The hexadecimal digest of the hash.
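
    Examples
    --------
    A sketch (the file name is illustrative):

    >>> md5sum('sample.fif')  # doctest: +SKIP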
    """
    md5 = hashlib.md5()
    with open(fname, 'rb') as fid:
        while True:
            data = fid.read(block_size)
            if not data:
                break
            md5.update(data)
    return md5.hexdigest()


def create_slices(start, stop, step=None, length=1):
    """Generate slices of time indexes.

    Parameters
    ----------
    start : int
        Index where the first slice should start.
    stop : int
        Index where the last slice should maximally end.
    step : int | None
        Number of time samples separating two slices.
        If step is None, step is set to length.
    length : int
        Number of time samples included in a given slice.

    Returns
    -------
    slices : list
        List of slice objects.
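
    Examples
    --------
    Two contiguous, non-overlapping slices of length 5 covering samples
    0 through 9:

    >>> create_slices(0, 10, length=5)  # doctest: +SKIP
    [slice(0, 5, 1), slice(5, 10, 1)]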
    """
    # default parameters
    if step is None:
        step = length

    # slicing
    slices = [slice(t, t + length, 1) for t in
              range(start, stop - length + 1, step)]
    return slices


def _time_mask(times, tmin=None, tmax=None, sfreq=None, raise_error=True):
    """Safely find sample boundaries."""
    orig_tmin = tmin
    orig_tmax = tmax
    tmin = -np.inf if tmin is None else tmin
    tmax = np.inf if tmax is None else tmax
    if not np.isfinite(tmin):
        tmin = times[0]
    if not np.isfinite(tmax):
        tmax = times[-1]
    if sfreq is not None:
        # Push to a bit past the nearest sample boundary first
        sfreq = float(sfreq)
        tmin = int(round(tmin * sfreq)) / sfreq - 0.5 / sfreq
        tmax = int(round(tmax * sfreq)) / sfreq + 0.5 / sfreq
    if raise_error and tmin > tmax:
        raise ValueError('tmin (%s) must be less than or equal to tmax (%s)'
                         % (orig_tmin, orig_tmax))
    mask = (times >= tmin)
    mask &= (times <= tmax)
    if raise_error and not mask.any():
        raise ValueError('No samples remain when using tmin=%s and tmax=%s '
                         '(original time bounds are [%s, %s])'
                         % (orig_tmin, orig_tmax, times[0], times[-1]))
    return mask
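

# A hedged illustration of _time_mask (the values shown are what the
# implementation returns for these inputs):
#
#     times = np.arange(5) / 10.  # 0.0, 0.1, 0.2, 0.3, 0.4 s
#     _time_mask(times, tmin=0.1, tmax=0.3)
#     # -> array([False, True, True, True, False])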


def random_permutation(n_samples, random_state=None):
    """Emulate the randperm matlab function.

    It returns a vector containing a random permutation of the
    integers between 0 and n_samples-1. It returns the same random numbers
    than randperm matlab function whenever the random_state is the same
    as the matlab's random seed.

    This function is useful for comparing against matlab scripts
    which use the randperm function.

    Note: the randperm(n_samples) matlab function generates a random
    sequence between 1 and n_samples, whereas
    random_permutation(n_samples, random_state) function generates
    a random sequence between 0 and n_samples-1, that is:
    randperm(n_samples) = random_permutation(n_samples, random_state) - 1

    Parameters
    ----------
    n_samples : int
        End point of the sequence to be permuted (excluded, i.e., the end point
        is equal to n_samples-1)
    random_state : int | None
        Random seed for initializing the pseudo-random number generator.

    Returns
    -------
    randperm : ndarray, int
        Randomly permuted sequence between 0 and n-1.
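
    Examples
    --------
    A sketch; the exact order depends on the seed:

    >>> perm = random_permutation(5, random_state=0)  # doctest: +SKIP
    >>> sorted(perm)  # doctest: +SKIP
    [0, 1, 2, 3, 4]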
    """
    rng = check_random_state(random_state)
    idx = rng.rand(n_samples)
    randperm = np.argsort(idx)
    return randperm


def compute_corr(x, y):
    """Compute pearson correlations between a vector and a matrix."""
    if len(x) == 0 or len(y) == 0:
        raise ValueError('x or y has zero length')
    X = np.array(x, float)
    Y = np.array(y, float)
    X -= X.mean(0)
    Y -= Y.mean(0)
    x_sd = X.std(0, ddof=1)
    # if covariance matrix is fully expanded, Y needs a
    # transpose / broadcasting else Y is correct
    y_sd = Y.std(0, ddof=1)[:, None if X.shape == Y.shape else Ellipsis]
    return (np.dot(X.T, Y) / float(len(X) - 1)) / (x_sd * y_sd)
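

# A hedged sanity check for compute_corr: correlating a vector with itself
# (as a single-column matrix) yields a correlation of 1:
#
#     x = np.random.RandomState(0).randn(100)
#     compute_corr(x, x[:, np.newaxis])  # -> array([ 1.])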


def grand_average(all_inst, interpolate_bads=True, drop_bads=True):
    """Make grand average of a list evoked or AverageTFR data.

    For evoked data, the function interpolates bad channels based on
    `interpolate_bads` parameter. If `interpolate_bads` is True, the grand
    average file will contain good channels and the bad channels interpolated
    from the good MEG/EEG channels.
    For AverageTFR data, the function takes the subset of channels not marked
    as bad in any of the instances.

    The grand_average.nave attribute will be equal to the number
    of evoked datasets used to calculate the grand average.

    Note: Grand average evoked should not be used for source localization.

    Parameters
    ----------
    all_inst : list of Evoked or AverageTFR data
        The evoked datasets.
    interpolate_bads : bool
        If True, bad MEG and EEG channels are interpolated. Ignored for
        AverageTFR.
    drop_bads : bool
        If True, drop all bad channels marked as bad in any data set.
        If neither interpolate_bads nor drop_bads is True, in the output file,
        every channel marked as bad in at least one of the input files will be
        marked as bad, but no interpolation or dropping will be performed.

    Returns
    -------
    grand_average : Evoked | AverageTFR
        The grand average data. Same type as input.

    Notes
    -----
    .. versionadded:: 0.11.0
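
    Examples
    --------
    A hedged sketch assuming ``evokeds`` is a list of ``Evoked`` instances
    from different subjects:

    >>> ga = grand_average(evokeds)  # doctest: +SKIP
    >>> ga.nave == len(evokeds)  # doctest: +SKIP
    True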
    """
    # check if all elements in the given list are evoked data
    from .evoked import Evoked
    from .time_frequency import AverageTFR
    from .channels.channels import equalize_channels
    assert len(all_inst) > 1
    inst_type = type(all_inst[0])
    _validate_type(all_inst[0], (Evoked, AverageTFR), 'All elements')
    for inst in all_inst:
        _validate_type(inst, inst_type, 'All elements', 'of the same type')

    # Copy channels to leave the original evoked datasets intact.
    all_inst = [inst.copy() for inst in all_inst]

    # Interpolates if necessary
    if isinstance(all_inst[0], Evoked):
        if interpolate_bads:
            all_inst = [inst.interpolate_bads() if len(inst.info['bads']) > 0
                        else inst for inst in all_inst]
        equalize_channels(all_inst)  # keep only the channels common to all
        from .evoked import combine_evoked as combine
    else:  # isinstance(all_inst[0], AverageTFR):
        from .time_frequency.tfr import combine_tfr as combine

    if drop_bads:
        bads = list(set((b for inst in all_inst for b in inst.info['bads'])))
        if bads:
            for inst in all_inst:
                inst.drop_channels(bads)

    # make grand_average object using combine_[evoked/tfr]
    grand_average = combine(all_inst, weights='equal')
    # change the grand_average.nave to the number of Evokeds
    grand_average.nave = len(all_inst)
    # change comment field
    grand_average.comment = "Grand average (n = %d)" % grand_average.nave
    return grand_average


def _get_root_dir():
    """Get as close to the repo root as possible."""
    root_dir = op.abspath(op.dirname(__file__))
    up_dir = op.join(root_dir, '..')
    if op.isfile(op.join(up_dir, 'setup.py')) and all(
            op.isdir(op.join(up_dir, x)) for x in ('mne', 'examples', 'doc')):
        root_dir = op.abspath(up_dir)
    return root_dir


def sys_info(fid=None, show_paths=False):
    """Print the system information for debugging.

    This function is useful for printing system information
    to help triage bugs.

    Parameters
    ----------
    fid : file-like | None
        The file to write to. Will be passed to :func:`print()`.
        Can be None to use :data:`sys.stdout`.
    show_paths : bool
        If True, print paths for each module.

    Examples
    --------
    Running this function with no arguments prints an output that is
    useful when submitting bug reports::

        >>> import mne
        >>> mne.sys_info() # doctest: +SKIP
        Platform:      Linux-4.2.0-27-generic-x86_64-with-Ubuntu-15.10-wily
        Python:        2.7.10 (default, Oct 14 2015, 16:09:02)  [GCC 5.2.1 20151010]
        Executable:    /usr/bin/python

        mne:           0.12.dev0
        numpy:         1.12.0.dev0+ec5bd81 {lapack=mkl_rt, blas=mkl_rt}
        scipy:         0.18.0.dev0+3deede3
        matplotlib:    1.5.1+1107.g1fa2697

        sklearn:       0.18.dev0
        nibabel:       2.1.0dev
        mayavi:        4.3.1
        cupy:          4.1.0
        pandas:        0.17.1+25.g547750a
        dipy:          0.14.0

    """  # noqa: E501
    ljust = 15
    out = 'Platform:'.ljust(ljust) + platform.platform() + '\n'
    out += 'Python:'.ljust(ljust) + str(sys.version).replace('\n', ' ') + '\n'
    out += 'Executable:'.ljust(ljust) + sys.executable + '\n'
    out += 'CPU:'.ljust(ljust) + ('%s: %s cores\n' %
                                  (platform.processor(),
                                   multiprocessing.cpu_count()))
    out += 'Memory:'.ljust(ljust)
    try:
        import psutil
    except ImportError:
        out += 'Unavailable (requires "psutil" package)\n'
    else:
        out += '%0.1f GB\n' % (psutil.virtual_memory().total / float(2 ** 30),)
    out += '\n'
    old_stdout = sys.stdout
    capture = StringIO()
    try:
        sys.stdout = capture
        np.show_config()
    finally:
        sys.stdout = old_stdout
    lines = capture.getvalue().split('\n')
    libs = []
    for li, line in enumerate(lines):
        for key in ('lapack', 'blas'):
            if line.startswith('%s_opt_info' % key):
                lib = lines[li + 1]
                if 'NOT AVAILABLE' in lib:
                    lib = 'unknown'
                else:
                    lib = lib.split('[')[1].split("'")[1]
                libs += ['%s=%s' % (key, lib)]
    libs = ', '.join(libs)
    for mod_name in ('mne', 'numpy', 'scipy', 'matplotlib', '', 'sklearn',
                     'nibabel', 'mayavi', 'cupy', 'pandas', 'dipy'):
        if mod_name == '':
            out += '\n'
            continue
        out += ('%s:' % mod_name).ljust(ljust)
        try:
            mod = __import__(mod_name)
            if mod_name == 'mayavi':
                # the real test
                from mayavi import mlab  # noqa, analysis:ignore
        except Exception:
            out += 'Not found\n'
        else:
            extra = (' (%s)' % op.dirname(mod.__file__)) if show_paths else ''
            if mod_name == 'numpy':
                extra = ' {%s}%s' % (libs, extra)
            elif mod_name == 'matplotlib':
                extra = ' {backend=%s}%s' % (mod.get_backend(), extra)
            elif mod_name == 'mayavi':
                try:
                    from pyface.qt import qt_api
                except Exception:
                    qt_api = 'unknown'
                if qt_api == 'pyqt5':
                    try:
                        from PyQt5.Qt import PYQT_VERSION_STR
                        qt_api += ', PyQt5=%s' % (PYQT_VERSION_STR,)
                    except Exception:
                        pass
                extra = ' {qt_api=%s}%s' % (qt_api, extra)
            out += '%s%s\n' % (mod.__version__, extra)
    print(out, end='', file=fid)


class ETSContext(object):
    """Add more meaningful message to errors generated by ETS Toolkit."""

    def __enter__(self):  # noqa: D105
        pass

    def __exit__(self, type, value, traceback):  # noqa: D105
        if (isinstance(value, SystemExit) and
                isinstance(value.code, string_types) and
                value.code.startswith(
                    "This program needs access to the screen")):
            value.code += ("\nThis can probably be solved by setting "
                           "ETS_TOOLKIT=qt4. On bash, type\n\n    $ export "
                           "ETS_TOOLKIT=qt4\n\nand run the command again.")


def open_docs(kind=None, version=None):
    """Launch a new web browser tab with the MNE documentation.

    Parameters
    ----------
    kind : str | None
        Can be "api" (default), "tutorials", or "examples".
        The default can be changed by setting the configuration value
        MNE_DOCS_KIND.
    version : str | None
        Can be "stable" (default) or "dev".
        The default can be changed by setting the configuration value
        MNE_DOCS_VERSION.
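
    Examples
    --------
    Open the stable tutorials in a new browser tab:

    >>> open_docs('tutorials', 'stable')  # doctest: +SKIP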
    """
    if kind is None:
        kind = get_config('MNE_DOCS_KIND', 'api')
    help_dict = dict(api='python_reference.html', tutorials='tutorials.html',
                     examples='auto_examples/index.html')
    if kind not in help_dict:
        raise ValueError('kind must be one of %s, got %s'
                         % (sorted(help_dict.keys()), kind))
    kind = help_dict[kind]
    if version is None:
        version = get_config('MNE_DOCS_VERSION', 'stable')
    versions = ('stable', 'dev')
    if version not in versions:
        raise ValueError('version must be one of %s, got %s'
                         % (versions, version))
    webbrowser.open_new_tab('https://martinos.org/mne/%s/%s' % (version, kind))


def _is_numeric(n):
    return isinstance(n, (np.integer, np.floating, int, float))


def _validate_type(item, types=None, item_name=None, type_name=None):
    """Validate that `item` is an instance of `types`.

    Parameters
    ----------
    item : obj
        The thing to be checked.
    types : type | tuple of types | str
        The types to be checked against. If str, must be one of 'str',
        'int', 'numeric', or 'info'.
    item_name : str | None
        Name of the item to use in the error message.
    type_name : str | None
        Name of the type(s) to use in the error message.
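
    Examples
    --------
    Validation passes silently or raises TypeError (the name 'n_jobs' is
    illustrative):

    >>> _validate_type(1, 'int', 'n_jobs')  # doctest: +SKIP
    >>> _validate_type('foo', 'int', 'n_jobs')  # doctest: +SKIP
    Traceback (most recent call last):
        ...
    TypeError: n_jobs must be an int, got <... 'str'>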
    """
    if types == "int":
        _ensure_int(item, name=item_name)
        return  # terminate prematurely
    elif types == "str":
        types = string_types
        type_name = "str" if type_name is None else type_name
    elif types == "numeric":
        types = (np.integer, np.floating, int, float)
        type_name = "numeric" if type_name is None else type_name
    elif types == "info":
        from mne.io import Info as types
        type_name = "Info" if type_name is None else type_name
        item_name = "Info" if item_name is None else item_name

    if type_name is None:
        iter_types = ([types] if not isinstance(types, (list, tuple))
                      else types)
        type_name = ', '.join(cls.__name__ for cls in iter_types)
    if not isinstance(item, types):
        raise TypeError('%s must be an instance of %s, got %s instead'
                        % (item_name, type_name, type(item),))


def linkcode_resolve(domain, info):
    """Determine the URL corresponding to a Python object.

    Parameters
    ----------
    domain : str
        Only the 'py' domain is handled; others return None.
    info : dict
        With keys "module" and "fullname".

    Returns
    -------
    url : str
        The code URL.

    Notes
    -----
    This has been adapted to deal with our "verbose" decorator.

    Adapted from SciPy (doc/source/conf.py).
    """
    import mne
    if domain != 'py':
        return None

    modname = info['module']
    fullname = info['fullname']

    submod = sys.modules.get(modname)
    if submod is None:
        return None

    obj = submod
    for part in fullname.split('.'):
        try:
            obj = getattr(obj, part)
        except Exception:
            return None

    try:
        fn = inspect.getsourcefile(obj)
    except Exception:
        fn = None
    if not fn:
        try:
            fn = inspect.getsourcefile(sys.modules[obj.__module__])
        except Exception:
            fn = None
    if not fn:
        return None
    if fn == '<string>':  # verbose decorator
        fn = inspect.getmodule(obj).__file__
    fn = op.relpath(fn, start=op.dirname(mne.__file__))
    fn = '/'.join(op.normpath(fn).split(os.sep))  # in case on Windows

    try:
        source, lineno = inspect.getsourcelines(obj)
    except Exception:
        lineno = None

    if lineno:
        linespec = "#L%d-L%d" % (lineno, lineno + len(source) - 1)
    else:
        linespec = ""

    if 'dev' in mne.__version__:
        kind = 'master'
    else:
        kind = 'maint/%s' % ('.'.join(mne.__version__.split('.')[:2]))
    return "http://github.com/mne-tools/mne-python/blob/%s/mne/%s%s" % (  # noqa
       kind, fn, linespec)


def _check_if_nan(data, msg=" to be plotted"):
    """Raise if any of the values are NaN."""
    if not np.isfinite(data).all():
        raise ValueError("Some of the values {} are NaN.".format(msg))
