# Authors: The MNE-Python contributors.
# License: BSD-3-Clause
# Copyright the MNE-Python contributors.

import itertools
from pathlib import Path

import numpy as np
import pytest

from mne import Epochs, create_info, io, pick_types, read_events
from mne.channels import make_standard_montage
from mne.preprocessing import equalize_bads, interpolate_bridged_electrodes
from mne.preprocessing.interpolate import _find_centroid_sphere
from mne.transforms import _cart_to_sph

# Test data directory: ``io/tests/data`` under the grandparent of the directory
# that contains this file (``parents[2]``).
base_dir = Path(__file__).parents[2] / "io" / "tests" / "data"
# Raw recording and matching events file read by ``_load_data``.
raw_fname = base_dir / "test_raw.fif"
event_name = base_dir / "test-eve.fif"
# NOTE(review): not referenced by the code visible in this module - confirm
# whether it is still needed.
raw_fname_ctf = base_dir / "test_ctf_raw.fif"
# Epoching parameters passed to ``Epochs`` in ``_load_data``: the event id and
# the tmin/tmax window.
event_id, tmin, tmax = 1, -0.2, 0.5
# NOTE(review): not referenced by the code visible in this module - confirm
# whether it is still needed.
event_id_2 = 2


def _load_data():
    """Build the Raw, Epochs and Evoked objects shared by the tests.

    The raw file is read lazily and restricted to its EEG and stim channels.
    Epochs are cut around ``event_id`` in the ``tmin``/``tmax`` window on the
    first 15 EEG channels, with ``reject=dict(eeg=80e-6)``, and averaged into
    an Evoked.

    Returns
    -------
    raw, epochs, evoked
        The results of ``raw.load_data()`` and ``epochs.load_data()``, and the
        average of the epochs.
    """
    # Reading inside a function (rather than at import time) keeps memory
    # usage down: the data are only loaded by the tests that ask for them.
    raw = io.read_raw_fif(raw_fname)
    raw = raw.pick(["eeg", "stim"])
    events = read_events(event_name)

    # only a handful of EEG channels are needed, which keeps the tests fast
    eeg_picks = pick_types(raw.info, meg=False, eeg=True, exclude=[])
    epochs_kwargs = dict(
        picks=eeg_picks[:15],
        preload=True,
        reject=dict(eeg=80e-6),
    )
    epochs = Epochs(raw, events, event_id, tmin, tmax, **epochs_kwargs)
    evoked = epochs.average()

    # the raw data are only loaded once the epochs have been extracted
    loaded_raw = raw.load_data()
    loaded_epochs = epochs.load_data()
    return loaded_raw, loaded_epochs, evoked


@pytest.mark.parametrize("interp_thresh", [0.0, 0.5, 1.0])
@pytest.mark.parametrize("inst_type", ["raw", "epochs", "evoked"])
def test_equalize_bads(interp_thresh, inst_type):
    """Test equalize_bads function.

    Two instances of the requested type are built, the second one spanning
    more data than the first, and given partially overlapping bad channels.
    After ``equalize_bads`` every returned instance must carry the same set of
    bads, which depends on ``interp_thresh``.

    Parameters
    ----------
    interp_thresh : float
        Threshold forwarded to ``equalize_bads``.
    inst_type : str
        Which kind of instance to test: ``"raw"``, ``"epochs"`` or
        ``"evoked"``.
    """
    raw, epochs, evoked = _load_data()

    if inst_type == "raw":
        insts = [raw.copy().crop(0, 1), raw.copy().crop(0, 2)]
    elif inst_type == "epochs":
        insts = [epochs.copy()[:1], epochs.copy()[:2]]
    else:
        # Both entries must be Evoked so that this parametrization really
        # exercises a list of Evoked objects (previously the second entry was
        # a cropped Raw).
        insts = [evoked.copy().crop(0, 0.1), evoked.copy().crop(0, 0.2)]

    # an out-of-range threshold must be rejected
    with pytest.raises(ValueError, match="between 0"):
        equalize_bads(insts, interp_thresh=2.0)

    # first instance: bads[0] and bads[1]; second instance: bads[1] and bads[2]
    bads = insts[0].copy().pick("eeg").ch_names[:3]
    insts[0].info["bads"] = bads[:2]
    insts[1].info["bads"] = bads[1:]

    insts_ok = equalize_bads(insts, interp_thresh=interp_thresh)
    if interp_thresh == 0:
        # no channel is expected to remain bad
        bads_ok = []
    elif interp_thresh == 1:
        # every channel that was bad in any instance is expected to stay bad
        bads_ok = bads
    else:  # interp_thresh == 0.5
        # only the channels that are bad in the longer instance stay bad
        bads_ok = bads[1:]

    for inst in insts_ok:
        assert set(inst.info["bads"]) == set(bads_ok)


def test_interpolate_bridged_electrodes():
    """Test interpolate_bridged_electrodes function.

    For a bridged pair and then a bridged triplet, on Raw, Epochs and Evoked,
    check that the call leaves the channel names and bads untouched, adds no
    "virtual" channel, and yields data whose mean absolute difference from a
    regular ``interpolate_bads`` of the same channels lies within fixed bounds.
    Then check the ``bad_limit`` argument: the error for too large a bridged
    group, acceptance when the limit is raised, and rejection of a negative
    value.
    """
    raw, epochs, evoked = _load_data()
    # first a single bridged pair, then three mutually bridged electrodes; the
    # same three objects are reused for both groups
    bridged_groups = (
        ["EEG 001", "EEG 002"],
        ["EEG 001", "EEG 002", "EEG 003"],
    )
    for bridged_names in bridged_groups:
        for inst in (raw, epochs, evoked):
            # every pair of channels within the group is declared bridged
            indices = [inst.ch_names.index(name) for name in bridged_names]
            bridged_idx = list(itertools.combinations(indices, 2))
            ch_names_orig = inst.ch_names.copy()
            bads_orig = inst.info["bads"].copy()
            # reference: regular interpolation with the group marked as bad
            inst2 = inst.copy()
            inst2.info["bads"] = list(bridged_names)
            inst2.interpolate_bads()
            data_interp_reg = inst2.get_data(picks=bridged_names)
            inst = interpolate_bridged_electrodes(inst, bridged_idx)
            data_interp = inst.get_data(picks=bridged_names)
            assert not any("virtual" in ch for ch in inst.ch_names)
            assert inst.ch_names == ch_names_orig
            assert inst.info["bads"] == bads_orig
            # check closer to regular interpolation than original data
            assert 1e-6 < np.mean(np.abs(data_interp - data_interp_reg)) < 5.4e-5

    # test bad_limit
    montage = make_standard_montage("standard_1020")
    excluded = ("P7", "P8", "T3", "T4", "T5", "T6")
    ch_names = [ch for ch in montage.ch_names if ch not in excluded]
    info = create_info(ch_names, sfreq=1024, ch_types="eeg")
    # seeded generator so that the test data are reproducible
    rng = np.random.default_rng(0)
    data = rng.standard_normal((len(ch_names), 1024))
    # the first five channels carry an identical constant signal
    data[:5, :] = np.ones((5, 1024))
    raw = io.RawArray(data, info)
    raw.set_montage("standard_1020")
    bridged_idx = list(itertools.combinations(range(5), 2))
    with pytest.raises(
        RuntimeError,
        match="The channels Fp1, Fpz, Fp2, AF9, AF7 are bridged "
        "together and form a large area of bridged electrodes.",
    ):
        interpolate_bridged_electrodes(raw, bridged_idx, bad_limit=4)
    # increase the limit to prevent raising
    interpolate_bridged_electrodes(raw, bridged_idx, bad_limit=5)
    # invalid argument
    with pytest.raises(
        ValueError, match="Argument 'bad_limit' should be a strictly positive integer."
    ):
        interpolate_bridged_electrodes(raw, bridged_idx, bad_limit=-4)


def test_find_centroid():
    """Test that the centroid is correct.

    Centroids are computed with ``_find_centroid_sphere`` for several pairs
    and triplets of channels of the standard 10-20 montage and validated with
    ``_check_centroid_position``.
    """
    excluded = ("P7", "P8", "T3", "T4", "T5", "T6")
    std_montage = make_standard_montage("standard_1020")
    kept = [name for name in std_montage.ch_names if name not in excluded]
    info = create_info(kept, sfreq=1024, ch_types="eeg")
    info.set_montage(std_montage)
    pos = info.get_montage().get_positions()
    assert pos["coord_frame"] == "head"
    ch_pos = pos["ch_pos"]

    groups = [
        # T7 / TP7: averaging in spherical coordinates fails here and places
        # the result on the wrong side of the head, between T8 and TP8
        ["T7", "TP7"],
        # other pairs
        ("CPz", "CP2"),
        ("CPz", "Cz"),
        ("Fpz", "AFz"),
        ("AF7", "F7"),
        ("O1", "O2"),
        ("M2", "A2"),
        ("P5", "P9"),
        # triplets
        ("CPz", "Cz", "FCz"),
        ("AF9", "Fpz", "AF10"),
        ("FT10", "FT8", "T10"),
    ]
    for group in groups:
        centroid = _find_centroid_sphere(ch_pos, group)
        _check_centroid_position(pos, group, centroid)


def _check_centroid_position(pos, ch_names, pos_centroid):
    """Assert that a centroid lies where expected relative to the mean point.

    The distance between ``pos_centroid`` and the cartesian average of the
    channel positions must equal (within ``atol=1e-6``) the absolute
    difference between the average channel radius and the radius of that
    cartesian average.

    Parameters
    ----------
    pos : dict
        Montage positions; only ``pos["ch_pos"]`` is read.
    ch_names : sequence of str
        Channels the centroid was computed from.
    pos_centroid : array-like
        Centroid position to validate.
    """
    ch_pos = pos["ch_pos"]
    coords = np.array([ch_pos[name] for name in ch_names], dtype=float)
    # radius of each individual channel
    norms = [_cart_to_sph(ch_pos[name])[0, 0] for name in ch_names]
    mean_coord = np.average(coords, axis=0)
    mean_coord_radius = _cart_to_sph(mean_coord)[0, 0]
    expected = np.abs(np.average(norms) - mean_coord_radius)
    # how far the centroid actually is from the cartesian average
    actual = np.linalg.norm(pos_centroid - mean_coord)
    assert np.isclose(expected, actual, atol=1e-6)
