# Authors: The MNE-Python contributors.
# License: BSD-3-Clause
# Copyright the MNE-Python contributors.

from pathlib import Path

import numpy as np
import pytest
from numpy.testing import (
    assert_allclose,
    assert_array_almost_equal,
    assert_array_equal,
    assert_equal,
)

pytest.importorskip("sklearn")

from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import StratifiedKFold, cross_val_score
from sklearn.pipeline import Pipeline, make_pipeline
from sklearn.svm import SVC

from mne import Epochs, compute_proj_raw, io, pick_types, read_events
from mne.decoding import CSP, LinearModel, Scaler, SPoC, get_coef
from mne.decoding.csp import _ajd_pham
from mne.utils import catch_logging

# Bundled I/O test fixtures live three directory levels above this file.
data_dir = Path(__file__).parents[2].joinpath("io", "tests", "data")
raw_fname = data_dir.joinpath("test_raw.fif")
event_name = data_dir.joinpath("test-eve.fif")
# Epoch window (seconds, relative to each event) and the two conditions used.
tmin, tmax = -0.1, 0.2
event_id = {"aud_l": 1, "vis_l": 3}
# If stop is too small, PCA may fail in some cases, but this file is fine.
# NOTE(review): start/stop do not appear to be used by the tests in this
# module -- confirm before removing.
start, stop = 0, 8


def simulate_data(target, n_trials=100, n_channels=10, random_state=42, n_times=50):
    """Simulate data according to an instantaneous mixing model.

    Data are simulated in the statistical source space, where the power of
    the first source is modulated trial-by-trial by ``target``, before being
    mixed into channel space by a random orthogonal mixing matrix. The
    remaining sources are low-amplitude noise.

    Parameters
    ----------
    target : array-like, shape (n_trials,)
        Non-negative per-trial target. The first source's amplitude is scaled
        by ``sqrt(target)`` so that its power scales with ``target``.
    n_trials : int
        Number of trials; should match ``len(target)``.
    n_channels : int
        Number of sources, which is also the number of channels.
    random_state : int
        Seed for :class:`numpy.random.RandomState`.
    n_times : int
        Number of time samples per trial.

    Returns
    -------
    X : ndarray, shape (n_trials, n_channels, n_times)
        Simulated channel-space data.
    mixing_mat : ndarray, shape (n_channels, n_channels)
        The orthogonal mixing matrix; column ``k`` maps source ``k`` to
        channels.
    """
    rs = np.random.RandomState(random_state)

    # Orthogonal mixing matrix: left singular vectors of a random square matrix.
    mixing_mat = np.linalg.svd(rs.randn(n_channels, n_channels))[0]

    sources = rs.randn(n_trials, n_channels, n_times)
    # Scale the first source so its per-trial power follows the target.
    sources[:, 0] *= np.atleast_2d(np.sqrt(target)).T
    sources[:, 1:] *= 0.01  # less noise

    # Mix every trial at once; ``@`` broadcasts the (n_channels, n_channels)
    # matrix over the leading trial axis.
    X = mixing_mat @ sources

    return X, mixing_mat


def deterministic_toy_data(classes=("class_a", "class_b")):
    """Build a tiny, fully deterministic two-epoch data set.

    Each epoch holds four +/-1 square-wave sources over 16 samples, where
    source ``k`` follows bit ``k`` of the sample index. The first epoch
    amplifies sources 0 and 1 (gains 1 and 2); the second amplifies sources
    2 and 3 (gains 3 and 4). Both epochs are then mixed into four channels
    by the same fixed symmetric mixing matrix.

    Parameters
    ----------
    classes : sequence of str
        Labels of the two epochs, in order.

    Returns
    -------
    x : ndarray, shape (2, 4, 16)
        Mixed signals, one epoch per class.
    y : ndarray
        ``np.array(classes)``.
    """
    n_sources, n_samples = 4, 16
    sample_idx = np.arange(n_samples)
    source_idx = np.arange(n_sources)[:, np.newaxis]
    # Source k is bit k of the sample index, mapped from {0, 1} to {-1, +1}.
    base = ((sample_idx >> source_idx) & 1).astype(float) * 2 - 1

    # Per-source amplitude gains for each class (column vectors broadcast
    # across samples).
    gains_a = np.array([1.0, 2.0, 1.0, 1.0])[:, np.newaxis]
    gains_b = np.array([1.0, 1.0, 3.0, 4.0])[:, np.newaxis]

    mixing = np.array(
        [
            [1.0, 0.8, 0.6, 0.4],
            [0.8, 1.0, 0.8, 0.6],
            [0.6, 0.8, 1.0, 0.8],
            [0.4, 0.6, 0.8, 1.0],
        ]
    )

    epochs = [mixing @ (base * gains) for gains in (gains_a, gains_b)]
    return np.stack(epochs), np.array(classes)


@pytest.mark.slowtest
def test_csp():
    """Test Common Spatial Patterns algorithm on epochs.

    Covers constructor/fit argument validation, fitting with and without
    trace normalization, the ``average_power`` and ``csp_space`` transforms,
    pattern/filter plotting, multi-class fitting, and recovery of a known
    mixing matrix from simulated data.
    """
    raw = io.read_raw_fif(raw_fname, preload=False)
    events = read_events(event_name)
    picks = pick_types(
        raw.info, meg=True, stim=False, ecg=False, eog=False, exclude="bads"
    )
    picks = picks[2:12:3]  # subselect channels -> disable proj!
    raw.add_proj([], remove_existing=True)
    epochs = Epochs(
        raw,
        events,
        event_id,
        tmin,
        tmax,
        picks=picks,
        baseline=(None, 0),
        preload=True,
        proj=False,
    )
    epochs_data = epochs.get_data(copy=False)
    n_channels = epochs_data.shape[1]
    y = epochs.events[:, -1]

    # Init
    pytest.raises(ValueError, CSP, n_components="foo", norm_trace=False)
    for reg in ["foo", -0.1, 1.1]:
        csp = CSP(reg=reg, norm_trace=False)
        pytest.raises(ValueError, csp.fit, epochs_data, epochs.events[:, -1])
    for reg in ["oas", "ledoit_wolf", 0, 0.5, 1.0]:
        CSP(reg=reg, norm_trace=False)
    for cov_est in ["foo", None]:
        pytest.raises(ValueError, CSP, cov_est=cov_est, norm_trace=False)
    with pytest.raises(TypeError, match="instance of bool"):
        CSP(norm_trace="foo")
    for cov_est in ["concat", "epoch"]:
        CSP(cov_est=cov_est, norm_trace=False)

    n_components = 3
    # Fit with and without trace normalization, checking the per-component
    # statistics of each variant (not only the last one fitted).
    for norm_trace in [True, False]:
        csp = CSP(n_components=n_components, norm_trace=norm_trace)
        csp.fit(epochs_data, epochs.events[:, -1])
        assert_equal(len(csp.mean_), n_components)
        assert_equal(len(csp.std_), n_components)

    # Transform (using the norm_trace=False estimator left by the loop)
    X = csp.fit_transform(epochs_data, y)
    sources = csp.transform(epochs_data)
    assert sources.shape[1] == n_components
    assert csp.filters_.shape == (n_channels, n_channels)
    assert csp.patterns_.shape == (n_channels, n_channels)
    assert_array_almost_equal(sources, X)

    # Test data exception
    pytest.raises(ValueError, csp.fit, epochs_data, np.zeros_like(epochs.events))
    pytest.raises(ValueError, csp.fit, epochs, y)
    pytest.raises(ValueError, csp.transform, epochs)

    # Test plots
    epochs.pick(picks="mag")
    cmap = ("RdBu", True)
    components = np.arange(n_components)
    for plot in (csp.plot_patterns, csp.plot_filters):
        plot(epochs.info, components=components, res=12, show=False, cmap=cmap)

    # Test with more than 2 classes
    epochs = Epochs(
        raw,
        events,
        tmin=tmin,
        tmax=tmax,
        picks=picks,
        event_id=dict(aud_l=1, aud_r=2, vis_l=3, vis_r=4),
        baseline=(None, 0),
        proj=False,
        preload=True,
    )
    epochs_data = epochs.get_data(copy=False)
    n_channels = epochs_data.shape[1]

    for cov_est in ["concat", "epoch"]:
        csp = CSP(n_components=n_components, cov_est=cov_est, norm_trace=False)
        csp.fit(epochs_data, epochs.events[:, 2]).transform(epochs_data)
        assert_equal(len(csp._classes), 4)
        assert_array_equal(csp.filters_.shape, [n_channels, n_channels])
        assert_array_equal(csp.patterns_.shape, [n_channels, n_channels])

    # Test average power transform
    n_components = 2
    assert csp.transform_into == "average_power"
    feature_shape = [len(epochs_data), n_components]
    X_trans = dict()
    for log in (None, True, False):
        csp = CSP(n_components=n_components, log=log, norm_trace=False)
        assert csp.log is log
        Xt = csp.fit_transform(epochs_data, epochs.events[:, 2])
        assert_array_equal(Xt.shape, feature_shape)
        X_trans[str(log)] = Xt
    # log=None => log=True
    assert_array_almost_equal(X_trans["None"], X_trans["True"])
    # Different normalization return different transform
    assert np.sum((X_trans["True"] - X_trans["False"]) ** 2) > 1.0
    # Check wrong inputs
    pytest.raises(ValueError, CSP, transform_into="average_power", log="foo")

    # Test csp space transform
    csp = CSP(transform_into="csp_space", norm_trace=False)
    assert csp.transform_into == "csp_space"
    for log in ("foo", True, False):
        pytest.raises(
            ValueError, CSP, transform_into="csp_space", log=log, norm_trace=False
        )
    n_components = 2
    csp = CSP(n_components=n_components, transform_into="csp_space", norm_trace=False)
    Xt = csp.fit(epochs_data, epochs.events[:, 2]).transform(epochs_data)
    feature_shape = [len(epochs_data), n_components, epochs_data.shape[2]]
    assert_array_equal(Xt.shape, feature_shape)

    # Check mixing matrix on simulated data
    y = np.array([100] * 50 + [1] * 50)
    X, A = simulate_data(y)

    for cov_est in ["concat", "epoch"]:
        # fit csp
        csp = CSP(n_components=1, cov_est=cov_est, norm_trace=False)
        csp.fit(X, y)

        # The first pattern should match the first mixing column up to sign,
        # hence the absolute correlation.
        corr = np.abs(np.corrcoef(csp.patterns_[0, :], A[:, 0])[0, 1])
        assert corr > 0.99

        # The first CSP component should track the (power) target.
        out = csp.transform(X)
        corr = np.abs(np.corrcoef(out[:, 0], y)[0, 1])
        assert corr > 0.95


# Even the "reg is None and rank is None" case should pass now, thanks to the
# rank being computed from the data (do_compute_rank).
@pytest.mark.parametrize("ch_type", ("mag", "eeg", ("mag", "eeg")))
@pytest.mark.parametrize("rank", (None, "full", "correct"))
@pytest.mark.parametrize("reg", [None, 0.001, "oas"])
def test_regularized_csp(ch_type, rank, reg):
    """Test Common Spatial Patterns algorithm using regularized covariance.

    Parametrized over channel type, the ``rank`` argument (``None``,
    ``"full"``, or an explicit per-type dict built for ``"correct"``) and
    ``reg``. Checks the rank-reduction log messages, filter/pattern shapes,
    a cross-validated score, and that ``get_coef`` undoes the Scaler as
    expected.
    """
    raw = io.read_raw_fif(raw_fname).pick(ch_type, exclude="bads").load_data()
    n_orig = len(raw.ch_names)
    # Keep every other channel to make the test cheaper.
    ch_decim = 2
    raw.pick_channels(raw.ch_names[::ch_decim])
    raw.info.normalize_proj()
    # Works for both the string "eeg" and the tuple ("mag", "eeg").
    if "eeg" in ch_type:
        raw.set_eeg_reference(projection=True)
        # TODO: for some reason we need to add a second EEG projector in order to get
        # the non-semidefinite error for EEG data. Hopefully this won't make much
        # difference in practice given our default is rank=None and regularization
        # is easy to use.
        raw.add_proj(compute_proj_raw(raw, n_eeg=1, n_mag=0, n_grad=0, n_jobs=1))
    # Expected data rank: one dimension lost per projector.
    n_eig = len(raw.ch_names) - len(raw.info["projs"])
    n_ch = n_orig // ch_decim
    # Projector counts: 2 for EEG (average ref + computed one), 3 for MAG,
    # 5 for the combination.
    if ch_type == "eeg":
        assert n_eig == n_ch - 2
    elif ch_type == "mag":
        assert n_eig == n_ch - 3
    else:
        assert n_eig == n_ch - 5
    # "correct": pass the true rank explicitly as a per-channel-type dict.
    if rank == "correct":
        if isinstance(ch_type, str):
            rank = {ch_type: n_eig}
        else:
            assert ch_type == ("mag", "eeg")
            # NOTE(review): hard-codes 102 MAG / 60 EEG channels in the
            # test file -- confirm if the fixture ever changes.
            rank = dict(
                mag=102 // ch_decim - 3,
                eeg=60 // ch_decim - 2,
            )
    else:
        assert rank is None or rank == "full", rank
    # "full": no rank reduction, so all channels are kept as components.
    if rank == "full":
        n_eig = n_ch
    raw.filter(2, 40).apply_proj()
    events = read_events(event_name)
    # Merge left and right events: aud_r (2) -> aud_l (1), vis_r (4) -> vis_l (3)
    events[events[:, 2] == 2, 2] = 1
    events[events[:, 2] == 4, 2] = 3
    epochs = Epochs(raw, events, event_id, tmin, tmax, decim=5, preload=True)
    epochs.equalize_event_counts()
    assert 25 < len(epochs) < 30
    epochs_data = epochs.get_data(copy=False)
    n_channels = epochs_data.shape[1]
    assert n_channels == n_ch
    n_components = 3

    sc = Scaler(epochs.info)
    # Keep unscaled data for the pipeline (which has its own Scaler) below.
    epochs_data_orig = epochs_data.copy()
    epochs_data = sc.fit_transform(epochs_data)
    csp = CSP(n_components=n_components, reg=reg, norm_trace=False, rank=rank)
    # Full rank without regularization on rank-deficient (projected) data is
    # expected to fail with a LinAlgError mentioning a "leading minor".
    if rank == "full" and reg is None:
        with pytest.raises(np.linalg.LinAlgError, match="leading minor"):
            csp.fit(epochs_data, epochs.events[:, -1])
        return
    with catch_logging(verbose=True) as log:
        X = csp.fit_transform(epochs_data, epochs.events[:, -1])
    log = log.getvalue()
    # Check what the rank-handling code logged (data are treated as a single
    # "data" type after scaling, never per-MAG).
    assert "Setting small MAG" not in log
    if rank != "full":
        assert "Setting small data eigen" in log
    else:
        assert "Setting small data eigen" not in log
    if rank is None:
        assert "Computing rank from data" in log
        assert " mag: rank" not in log.lower()
        assert " data: rank" in log
        assert "rank (mag)" not in log.lower()
        assert "rank (data)" in log
    elif rank != "full":  # if rank is passed no computation is done
        assert "Computing rank" not in log
        assert ": rank" not in log
        assert "rank (" not in log
    assert "reducing mag" not in log.lower()
    assert f"Reducing data rank from {n_channels} " in log
    y = epochs.events[:, -1]
    # Filters/patterns only span the retained (rank-reduced) subspace.
    assert csp.filters_.shape == (n_eig, n_channels)
    assert csp.patterns_.shape == (n_eig, n_channels)
    assert_array_almost_equal(csp.fit(epochs_data, y).transform(epochs_data), X)

    # test init exception
    pytest.raises(ValueError, csp.fit, epochs_data, np.zeros_like(epochs.events))
    pytest.raises(ValueError, csp.fit, epochs, y)
    pytest.raises(ValueError, csp.transform, epochs)

    csp.n_components = n_components
    sources = csp.transform(epochs_data)
    assert sources.shape[1] == n_components

    # Cross-validated decoding on the unscaled data through the full pipeline.
    cv = StratifiedKFold(5)
    clf = make_pipeline(
        sc,
        csp,
        LinearModel(LogisticRegression(solver="liblinear")),
    )
    score = cross_val_score(clf, epochs_data_orig, y, cv=cv, scoring="roc_auc").mean()
    assert 0.75 <= score <= 1.0

    # Test get_coef on CSP: manually un-scale the patterns and compare with
    # get_coef(..., inverse_transform=True), up to per-column norm and sign.
    clf.fit(epochs_data_orig, y)
    coef = csp.patterns_[:n_components]
    assert coef.shape == (n_components, n_channels), coef.shape
    coef = sc.inverse_transform(coef.T[np.newaxis])[0]
    assert coef.shape == (len(epochs.ch_names), n_components), coef.shape
    coef_mne = get_coef(clf, "patterns_", inverse_transform=True, verbose="debug")
    assert coef.shape == coef_mne.shape
    coef_mne /= np.linalg.norm(coef_mne, axis=0)
    coef /= np.linalg.norm(coef, axis=0)
    coef *= np.sign(np.sum(coef_mne * coef, axis=0))
    assert_allclose(coef_mne, coef)


def test_csp_pipeline():
    """Check that CSP parameters can be set and read back through a Pipeline."""
    steps = [("CSP", CSP(reg=1, norm_trace=False)), ("SVC", SVC())]
    pipe = Pipeline(steps)
    # Nested "<step>__<param>" syntax routes the value to the CSP step.
    pipe.set_params(**{"CSP__reg": 0.2})
    params = pipe.get_params()
    assert params["CSP__reg"] == 0.2


def test_ajd():
    """Test approximate joint diagonalization.

    The implementation should obtain the same results as the original MATLAB
    implementation by Pham Dinh-Tuan on a fixed, seeded set of matrices.
    """
    # Generate a set of covariance matrices sharing a common mixing matrix A:
    # each is A @ diag(d_i) @ A.T with diagonal entries close to 2.
    n_matrices, n_channels = 10, 3
    rng = np.random.RandomState(0)
    diags = 2.0 + 0.1 * rng.randn(n_matrices, n_channels)
    A = 2 * rng.rand(n_channels, n_channels) - 1
    # Normalize each row of A to unit Euclidean norm.
    A /= np.atleast_2d(np.sqrt(np.sum(A**2, 1))).T
    covmats = np.empty((n_matrices, n_channels, n_channels))
    for i, diag in enumerate(diags):
        covmats[i] = A @ np.diag(diag) @ A.T
    # Only the joint diagonalizer V is compared; the second output is unused.
    V, _ = _ajd_pham(covmats)
    # Results obtained with original matlab implementation
    V_matlab = [
        [-3.507280775058041, -5.498189967306344, 7.720624541198574],
        [0.694689013234610, 0.775690358505945, -1.162043086446043],
        [-0.592603135588066, -0.598996925696260, 1.009550086271192],
    ]
    assert_array_almost_equal(V, V_matlab)


def test_spoc():
    """Test Source Power Comodulation (SPoC).

    Checks output shapes for both ``transform_into`` modes, input
    validation, and that a single-component fit recovers the
    power-modulated source of simulated data.
    """
    # Seeded RNG so the shape/validation checks are reproducible run to run.
    rng = np.random.RandomState(0)
    X = rng.randn(10, 10, 20)
    y = rng.randn(10)

    spoc = SPoC(n_components=4)
    spoc.fit(X, y)
    Xt = spoc.transform(X)
    assert_array_equal(Xt.shape, [10, 4])
    spoc = SPoC(n_components=4, transform_into="csp_space")
    spoc.fit(X, y)
    Xt = spoc.transform(X)
    assert_array_equal(Xt.shape, [10, 4, 20])
    assert_array_equal(spoc.filters_.shape, [10, 10])
    assert_array_equal(spoc.patterns_.shape, [10, 10])

    # An all-zero target must be rejected.
    pytest.raises(ValueError, spoc.fit, X, y * 0)

    # Check that SPoC doesn't accept CSP-specific input
    pytest.raises(TypeError, SPoC, cov_est="epoch")

    # Check mixing matrix on simulated data
    rs = np.random.RandomState(42)
    y = rs.rand(100) * 50 + 1
    X, A = simulate_data(y)

    # fit spoc
    spoc = SPoC(n_components=1)
    spoc.fit(X, y)

    # The first pattern should match the first mixing column up to sign.
    corr = np.abs(np.corrcoef(spoc.patterns_[0, :], A[:, 0])[0, 1])
    assert corr > 0.99

    # The first SPoC component should track the target.
    out = spoc.transform(X)
    corr = np.abs(np.corrcoef(out[:, 0], y)[0, 1])
    assert corr > 0.85


def test_csp_twoclass_symmetry():
    """Check that swapping the two class labels leaves CSP contrasts unchanged."""

    def _log_power_contrast(classes):
        # Difference in log average power between the two toy epochs.
        x, y = deterministic_toy_data(classes)
        csp = CSP(norm_trace=False, transform_into="average_power", log=True)
        log_power = csp.fit_transform(x, y)
        return log_power[0] - log_power[1]

    ratio_ab, ratio_ba = (
        _log_power_contrast(order)
        for order in (["class_a", "class_b"], ["class_b", "class_a"])
    )
    assert_array_almost_equal(ratio_ab, ratio_ba)


def test_csp_component_ordering():
    """Check CSP ``component_order`` validation and resulting pattern order."""
    x, y = deterministic_toy_data(["class_a", "class_b"])

    with pytest.raises(ValueError):
        CSP(component_order="invalid")

    # The "alternate" ordering is only defined for exactly two classes.
    two_class_only = CSP(component_order="alternate")
    with pytest.raises(ValueError):
        two_class_only.fit(np.zeros((3, 0, 0)), ["a", "b", "c"])

    patterns = {
        order: CSP(component_order=order).fit(x, y).patterns_
        for order in ("alternate", "mutual_info")
    }

    # The toy data's eigenvalues are [0.06, 0.1, 0.5, 0.8]:
    # "alternate" orders them as [0.8, 0.06, 0.5, 0.1], while
    # "mutual_info" orders them as [0.06, 0.1, 0.8, 0.5],
    # which explains the permutation below.
    assert_array_almost_equal(
        patterns["alternate"], patterns["mutual_info"][[2, 0, 3, 1]]
    )
