"""Test check utilities."""

# Authors: The MNE-Python contributors.
# License: BSD-3-Clause
# Copyright the MNE-Python contributors.

import os
import sys
from pathlib import Path

import numpy as np
import pytest

import mne
from mne import pick_channels_cov, read_vectorview_selection
from mne._fiff.pick import _picks_to_idx
from mne.datasets import testing
from mne.utils import (
    Bunch,
    _check_ch_locs,
    _check_fname,
    _check_info_inv,
    _check_option,
    _check_range,
    _check_sphere,
    _check_subject,
    _on_missing,
    _path_like,
    _record_warnings,
    _safe_input,
    _soft_import,
    _suggest,
    _validate_type,
    catch_logging,
    check_fname,
    check_random_state,
    check_version,
)

# Locations inside the MNE testing dataset (not downloaded automatically;
# tests that need the files are guarded by ``requires_testing_data``).
data_path = testing.data_path(download=False)
base_dir = data_path / "MEG" / "sample"
fname_raw = base_dir / "sample_audvis_trunc_raw.fif"
fname_event = base_dir / "sample_audvis_trunc_raw-eve.fif"
fname_fwd = base_dir / "sample_audvis_trunc-meg-vol-7-fwd.fif"
fname_mgz = data_path / "subjects" / "sample" / "mri" / "aseg.mgz"
# peak-to-peak rejection thresholds used when epoching in ``_get_data``
reject = {"grad": 4000e-13, "mag": 4e-12}


@testing.requires_testing_data
def test_check(tmp_path):
    """Test checking functions.

    Covers ``check_random_state`` input validation, ``_check_fname`` type,
    existence and read-permission checks, ``check_fname`` suffix validation,
    and ``_check_subject`` argument validation.
    """
    pytest.raises(ValueError, check_random_state, "foo")
    pytest.raises(TypeError, _check_fname, 1)
    _check_fname(Path("./foo"))
    fname = tmp_path / "foo"
    with open(fname, "wb"):
        pass
    assert fname.is_file()
    _check_fname(fname, overwrite="read", must_exist=True)
    orig_perms = os.stat(fname).st_mode
    os.chmod(fname, 0)
    try:
        # On Windows chmod only toggles the read-only flag, and root ignores
        # permission bits entirely, so the file stays readable in both cases.
        can_revoke_read = not sys.platform.startswith("win") and os.geteuid() != 0
        if can_revoke_read:
            with pytest.raises(PermissionError, match="read permissions"):
                _check_fname(fname, overwrite="read", must_exist=True)
    finally:
        # Restore permissions even if the check above fails, so the file can
        # still be removed and tmp_path cleaned up.
        os.chmod(fname, orig_perms)
    os.remove(fname)
    assert not fname.is_file()
    pytest.raises(OSError, check_fname, "foo", "tets-dip.x", (), (".fif",))
    pytest.raises(ValueError, _check_subject, None, None)
    pytest.raises(TypeError, _check_subject, None, 1)
    pytest.raises(TypeError, _check_subject, 1, None)
    # smoke tests for permitted types
    check_random_state(None).choice(1)
    check_random_state(0).choice(1)
    check_random_state(np.random.RandomState(0)).choice(1)
    check_random_state(np.random.default_rng(0)).choice(1)


@testing.requires_testing_data
@pytest.mark.parametrize(
    "suffix",
    ("_meg.fif", "_eeg.fif", "_ieeg.fif", "_meg.fif.gz", "_eeg.fif.gz", "_ieeg.fif.gz"),
)
def test_check_fname_suffixes(suffix, tmp_path):
    """Test that raw files saved under alternative suffixes can be re-read."""
    out_path = tmp_path / fname_raw.name.replace("_raw.fif", suffix)
    # a very short crop keeps the save/load round trip cheap
    short_raw = mne.io.read_raw_fif(fname_raw).crop(0, 0.1)
    short_raw.save(out_path)
    # reading the file back under the new suffix should succeed
    mne.io.read_raw_fif(out_path)


def _get_data():
    """Load the epochs, covariances and forward model shared by tests.

    Returns
    -------
    epochs : Epochs
        Preloaded epochs of every other left-temporal MEG channel.
    data_cov : Covariance
        Covariance over the post-stimulus window (0.01 to 0.15 s).
    noise_cov : Covariance
        Covariance over the pre-stimulus window (up to 0 s).
    forward : Forward
        Volume forward solution from the testing dataset.
    """
    fwd = mne.read_forward_solution(fname_fwd)
    raw = mne.io.read_raw_fif(fname_raw, preload=True)
    events = mne.read_events(fname_event)
    event_id = 1
    tmin, tmax = -0.1, 0.15

    # keep only every second left-temporal MEG channel, for speed
    selection = read_vectorview_selection("Left-temporal")
    meg_idx = mne.pick_types(raw.info, meg=True, selection=selection)[::2]
    raw.pick([raw.ch_names[idx] for idx in meg_idx])

    raw.info.normalize_proj()  # avoid projection warnings

    epochs = mne.Epochs(
        raw,
        events,
        event_id,
        tmin,
        tmax,
        proj=True,
        baseline=(None, 0.0),
        preload=True,
        reject=reject,
    )
    noise_cov = mne.compute_covariance(epochs, tmin=None, tmax=0.0)
    data_cov = mne.compute_covariance(epochs, tmin=0.01, tmax=0.15)
    return epochs, data_cov, noise_cov, fwd


@testing.requires_testing_data
def test_check_info_inv():
    """Test checks for common channels across fwd model and cov matrices.

    Verifies that ``_check_info_inv`` drops channels marked bad in the info or
    in either covariance, drops reference channels, and returns only the
    channels common to info, forward and covariances (logging how many were
    excluded).
    """
    epochs, data_cov, noise_cov, forward = _get_data()

    # make sure same channel lists exist in data to make testing life easier
    assert epochs.info["ch_names"] == data_cov.ch_names
    assert epochs.info["ch_names"] == noise_cov.ch_names

    # check whether bad channels get excluded from the channel selection
    # info
    info_bads = epochs.info.copy()
    info_bads["bads"] = info_bads["ch_names"][1:3]  # include two bad channels
    picks = _check_info_inv(info_bads, forward, noise_cov=noise_cov)
    # Check each index on its own: ``[1, 2] not in picks`` would look for the
    # list ``[1, 2]`` itself as an element and so could never fail.
    assert 1 not in picks
    assert 2 not in picks
    # covariance matrix
    data_cov_bads = data_cov.copy()
    data_cov_bads["bads"] = [data_cov_bads.ch_names[0]]
    picks = _check_info_inv(epochs.info, forward, data_cov=data_cov_bads)
    assert 0 not in picks
    # noise covariance matrix
    noise_cov_bads = noise_cov.copy()
    noise_cov_bads["bads"] = [noise_cov_bads.ch_names[1]]
    picks = _check_info_inv(epochs.info, forward, noise_cov=noise_cov_bads)
    assert 1 not in picks

    # test whether reference channels get deleted
    info_ref = epochs.info.copy()
    # 301 is FIFFV_REF_MEG_CH, i.e. pretend this is a reference MEG channel
    info_ref["chs"][0]["kind"] = 301
    picks = _check_info_inv(info_ref, forward, noise_cov=noise_cov)
    assert 0 not in picks

    # pick channels in all inputs and make sure common set is returned
    epochs.pick([epochs.ch_names[ii] for ii in range(10)])
    data_cov = pick_channels_cov(
        data_cov, include=[data_cov.ch_names[ii] for ii in range(5, 20)]
    )
    noise_cov = pick_channels_cov(
        noise_cov, include=[noise_cov.ch_names[ii] for ii in range(7, 12)]
    )
    with catch_logging() as log:
        picks = _check_info_inv(
            epochs.info, forward, noise_cov=noise_cov, data_cov=data_cov, verbose=True
        )
        assert list(range(7, 10)) == picks

    # make sure to inform the user that 7 channels were dropped
    # (there are 10 channels in epochs but only 3 were picked)
    log = log.getvalue()
    assert "Excluding 7 channel(s) missing" in log


def test_check_option():
    """Test checking the value of a parameter against a list of options.

    Allowed values must produce a truthy return, and disallowed values must
    raise a ValueError whose message lists the allowed options (with a special
    wording when only a single option exists).
    """
    import re  # used to match the literal error messages exactly

    allowed_values = ["valid", "good", "ok"]

    # Value is allowed
    for value in allowed_values:
        assert _check_option("option", value, allowed_values)
    assert _check_option("option", "valid", ["valid"])

    # Check error message for invalid value. ``match`` is a regular
    # expression and the message contains "." characters, so escape it to
    # compare the text literally.
    msg = (
        "Invalid value for the 'option' parameter. Allowed values are "
        "'valid', 'good', and 'ok', but got 'bad' instead."
    )
    with pytest.raises(ValueError, match=re.escape(msg)):
        _check_option("option", "bad", allowed_values)

    # Special error message if only one value is allowed
    msg = (
        "Invalid value for the 'option' parameter. The only allowed value "
        "is 'valid', but got 'bad' instead."
    )
    with pytest.raises(ValueError, match=re.escape(msg)):
        _check_option("option", "bad", ["valid"])


def test_path_like():
    """Test that _path_like accepts str and Path objects but not a dict."""
    as_str = str(base_dir)
    as_pathlib = Path(base_dir)
    not_a_path = dict(foo="bar")

    for candidate in (as_str, as_pathlib):
        assert _path_like(candidate) is True
    assert _path_like(not_a_path) is False


def test_validate_type():
    """Test _validate_type with the "int-like" and "array-like" aliases."""
    # an int is int-like, a bool is not
    _validate_type(1, "int-like")
    with pytest.raises(TypeError, match="int-like"):
        _validate_type(False, "int-like")

    # lists, tuples and sets are all array-like ...
    for container in ([1, 2, 3], (1, 2, 3), {1, 2, 3}):
        _validate_type(container, "array-like")
    # ... but a string is not array-like
    with pytest.raises(TypeError, match="array-like"):
        _validate_type("123", "array-like")


def test_check_range():
    """Test that _check_range accepts in-range values and rejects others."""
    # well inside the bounds: no error
    _check_range(10, 1, 100, "value")

    # below the lower bound
    with pytest.raises(ValueError, match="must be between"):
        _check_range(0, 1, 10, "value")

    # equal to the lower bound; presumably the trailing False flags make the
    # bounds exclusive, so this is rejected as well
    with pytest.raises(ValueError, match="must be between"):
        _check_range(1, 1, 10, "value", False, False)


@testing.requires_testing_data
def test_suggest():
    """Test that _suggest proposes close matches among aseg label names."""
    pytest.importorskip("nibabel")
    labels = mne.get_volume_labels_from_aseg(fname_mgz)

    # an empty query yields no suggestion
    assert _suggest("", labels) == ""

    # a single close match is suggested directly
    single = _suggest("Left-cerebellum", labels)
    assert single == " Did you mean 'Left-Cerebellum-Cortex'?"

    # several close matches are listed together
    several = _suggest("Cerebellum-Cortex", labels)
    expected = (
        " Did you mean one of ['Left-Cerebellum-Cortex', "
        "'Right-Cerebellum-Cortex', 'Left-Cerebral-Cortex']?"
    )
    assert several == expected


def test_on_missing():
    """Test each mode of _on_missing, plus rejection of an unknown mode."""
    message = "test"
    # "raise" -> ValueError, "warn" -> RuntimeWarning, "ignore" -> nothing
    with pytest.raises(ValueError, match=message):
        _on_missing("raise", message)
    with pytest.warns(RuntimeWarning, match=message):
        _on_missing("warn", message)
    _on_missing("ignore", message)

    bad_mode = "Invalid value for the 'on_missing' parameter"
    with pytest.raises(ValueError, match=bad_mode):
        _on_missing("foo", message)


def _matlab_input(msg):
    raise EOFError()


def test_safe_input(monkeypatch):
    """Test _safe_input when the builtin ``input`` cannot be used."""
    # every call to ``input`` inside mne.utils.check now raises EOFError
    monkeypatch.setattr(mne.utils.check, "input", _matlab_input)

    # with only ``alt`` given, the failure surfaces as a RuntimeError
    with pytest.raises(RuntimeError, match="Could not use input"):
        _safe_input("whatever", alt="nothing")

    # with ``use`` given, that value is returned instead
    result = _safe_input("whatever", use="nothing")
    assert result == "nothing"


@testing.requires_testing_data
def test_check_ch_locs():
    """Test that _check_ch_locs notices channels with missing locations."""
    info = mne.io.read_info(fname_raw)

    # initially every query reports valid locations
    assert _check_ch_locs(info=info)
    for sel in ([0], [0, 1], None):
        assert _check_ch_locs(info=info, picks=sel)
    for kind in ("meg", "mag", "grad", "eeg"):
        assert _check_ch_locs(info=info, ch_type=kind)

    # wipe the positions of every EEG channel
    eeg_idx = _picks_to_idx(info=info, picks="eeg")
    for ch_idx in eeg_idx:
        info["chs"][ch_idx]["loc"][:3] = np.nan

    # EEG-specific queries must now report missing locations ...
    assert _check_ch_locs(info=info, picks=eeg_idx) is False
    assert _check_ch_locs(info=info, ch_type="eeg") is False

    # ... while the "all channels" and magnetometer queries still pass
    assert _check_ch_locs(info=info)
    assert _check_ch_locs(info=info, ch_type="mag")


# Check a bunch of version schemes as of 2022/03/01
# We don't have to get this 100% generalized, but it would be nice if all
# of these worked.
@pytest.mark.parametrize(
    "version, want, have_unstripped",
    [
        # test some dev cases
        ("1.23.0.dev0+782.g1168868df6", "1.23", False),  # NumPy
        ("1.9.0.dev0+1485.b06254e", "1.9", False),  # SciPy
        ("3.6.0.dev1651+g30d6161406", "3.6", False),  # matplotlib
        ("1.1.dev0", "1.1", False),  # sklearn
        ("0.56.0dev0+39.gef1ba4c10", "0.56", False),  # numba
        ("9.1.0.rc1", "9.1", False),  # VTK
        ("0.3dev0", "0.3", False),  # mne-connectivity
        ("0.2.2.dev0", "0.2.2", False),  # mne-qt-browser
        ("3.2.2+150.g1e93bd5d", "3.2.2", True),  # nibabel
        # test some stable cases
        ("1.2.3", "1.2.3", True),
        ("1.2", "1.2", True),
        ("1", "1", True),
    ],
)
def test_strip_dev(version, want, have_unstripped, monkeypatch):
    """Test that check_version strips dev/rc/local parts of version strings."""

    def looks_stable(ver):
        # "stable" means every dot-separated component parses as an int
        try:
            [int(part) for part in ver.split(".")]
        except ValueError:
            return False
        return True

    # any module lookup now reports ``version`` as its __version__
    monkeypatch.setattr(
        mne.utils.check, "import_module", lambda x: Bunch(__version__=version)
    )

    # with strip=False the raw version string is returned untouched
    got_have_unstripped, same_version = check_version(
        version, want, strip=False, return_version=True
    )
    assert same_version == version
    assert got_have_unstripped is have_unstripped

    # strip=True is the default
    have, simpler_version = check_version("foo", want, return_version=True)
    assert have, (simpler_version, version)

    if not looks_stable(version):
        assert simpler_version != version
    else:
        assert "dev" not in version
        assert "rc" not in version
        assert simpler_version == version
    for token in ("dev", "rc"):
        assert token not in simpler_version
    assert not simpler_version.endswith(".")
    assert looks_stable(simpler_version)


@testing.requires_testing_data
def test_check_sphere_verbose():
    """Test that _check_sphere's warning honours the active log level."""
    info = mne.io.read_info(fname_raw)
    # keep only the first 20 digitization points
    with info._unlock():
        info["dig"] = info["dig"][:20]

    # the automatic sphere fit is expected to warn with this reduced dig ...
    with _record_warnings(), pytest.warns(RuntimeWarning, match="may be inaccurate"):
        _check_sphere("auto", info)

    # ... and presumably stays silent at the "error" log level
    with mne.use_log_level("error"):
        _check_sphere("auto", info)


def test_soft_import():
    """Test that _soft_import enforces ``min_version``."""
    # no installed mne can satisfy this minimum, so the import must fail
    expected = r".* the module mne>=999 \(found version.*"
    with pytest.raises(RuntimeError, match=expected):
        _soft_import("mne", "testing", min_version="999")
