# Authors: The MNE-Python contributors.
# License: BSD-3-Clause
# Copyright the MNE-Python contributors.


import numpy as np
import pytest
from numpy.testing import assert_allclose
from scipy.linalg import norm

from mne import SourceEstimate, read_source_spaces
from mne.datasets import testing
from mne.simulation import metrics
from mne.simulation.metrics import (
    cosine_score,
    f1_score,
    peak_position_error,
    precision_score,
    recall_score,
    region_localization_error,
    roc_auc_score,
    spatial_deviation_error,
)

data_path = testing.data_path(download=False)
src_fname = data_path / "subjects" / "sample" / "bem" / "sample-oct-6-src.fif"


@testing.requires_testing_data
def test_uniform_and_thresholding():
    """Test simulation metrics."""
    src = read_source_spaces(src_fname)
    vert = [src[0]["vertno"][0:1], []]
    data = np.array([[0.8, -1.0]])
    stc_true = SourceEstimate(data, vert, 0, 0.002, subject="sample")
    stc_bad = SourceEstimate(data, vert, 0, 0.002, subject="sample")
    stc_bad.vertices = [stc_bad.vertices[0]]
    with pytest.raises(ValueError, match="same number of vertices"):
        metrics._uniform_stc(stc_true, stc_bad)

    threshold = 0.9
    stc1, stc2 = metrics._thresholding(stc_true, stc_true, threshold)
    assert_allclose(stc1._data, np.array([[0, -1.0]]))
    assert_allclose(stc2._data, np.array([[0, -1.0]]))
    assert_allclose(threshold, metrics._check_threshold(threshold))

    threshold = "90"
    with pytest.raises(ValueError, match="Threshold if a str.*"):
        metrics._check_threshold(threshold)


@testing.requires_testing_data
def test_cosine_score():
    """Test simulation metrics."""
    src = read_source_spaces(src_fname)
    vert1 = [src[0]["vertno"][0:1], []]
    vert2 = [src[0]["vertno"][1:2], []]
    data1 = np.ones((1, 2))
    data2 = data1.copy()
    stc_true = SourceEstimate(data1, vert1, 0, 0.002, subject="sample")
    stc_est1 = SourceEstimate(data2, vert2, 0, 0.002, subject="sample")
    stc_est2 = SourceEstimate(data2, vert1, 0, 0.002, subject="sample")

    E_per_sample1 = cosine_score(stc_true, stc_est1)
    E_unique1 = cosine_score(stc_true, stc_est1, per_sample=False)

    E_per_sample2 = cosine_score(stc_true, stc_est2)
    E_unique2 = cosine_score(stc_true, stc_est2, per_sample=False)

    assert_allclose(E_per_sample1, np.zeros(2))
    assert_allclose(E_unique1, 0.0, atol=1e-08)
    assert_allclose(E_per_sample2, np.ones(2))
    assert_allclose(E_unique2, 1.0, atol=1e-08)


@testing.requires_testing_data
def test_region_localization_error():
    """Test simulation metrics."""
    pytest.importorskip("sklearn")
    src = read_source_spaces(src_fname)
    vert1 = [src[0]["vertno"][0:1], []]
    vert2 = [src[0]["vertno"][1:2], []]
    dist = norm(src[0]["rr"][vert1[0]] - src[0]["rr"][vert2[0]])
    data1 = np.ones((1, 2))
    data2 = np.array([[0.8, 1]])
    stc_true = SourceEstimate(data1, vert1, 0, 0.002, subject="sample")
    stc_est1 = SourceEstimate(data2, vert2, 0, 0.002, subject="sample")

    E_per_sample1 = region_localization_error(stc_true, stc_est1, src)
    E_per_sample2 = region_localization_error(stc_true, stc_est1, src, threshold="70%")
    E_unique = region_localization_error(stc_true, stc_est1, src, per_sample=False)

    assert_allclose(E_per_sample1, [np.inf, dist])
    assert_allclose(E_per_sample2, [dist, dist])
    assert_allclose(E_unique, dist)


@testing.requires_testing_data
def test_precision_score():
    """Test simulation metrics."""
    pytest.importorskip("sklearn")
    from sklearn.exceptions import UndefinedMetricWarning

    src = read_source_spaces(src_fname)
    vert1 = [src[0]["vertno"][0:2], []]
    vert2 = [src[0]["vertno"][1:3], []]
    vert3 = [src[0]["vertno"][0:1], []]
    data1 = np.ones((2, 2))
    data2 = np.ones((2, 2))
    data3 = np.array([[0.8, 1]])
    stc_true = SourceEstimate(data1, vert1, 0, 0.002, subject="sample")
    stc_est1 = SourceEstimate(data2, vert2, 0, 0.002, subject="sample")
    stc_est2 = SourceEstimate(data3, vert3, 0, 0.002, subject="sample")

    E_unique1 = precision_score(stc_true, stc_est1, per_sample=False)
    E_unique2 = precision_score(stc_true, stc_est2, per_sample=False)
    with pytest.warns(UndefinedMetricWarning, match="no predicted samples"):
        E_per_sample1 = precision_score(stc_true, stc_est2)
    E_per_sample2 = precision_score(stc_true, stc_est2, threshold="70%")
    with pytest.raises(ValueError, match="0 and 1"):
        precision_score(stc_true, stc_est2, threshold=2)

    # ### Tests to add
    assert_allclose(E_unique1, 0.5)
    assert_allclose(E_unique2, 1.0)
    assert_allclose(E_per_sample1, [0.0, 1.0])
    assert_allclose(E_per_sample2, [1.0, 1.0])


@testing.requires_testing_data
def test_recall_score():
    """Test simulation metrics."""
    pytest.importorskip("sklearn")
    src = read_source_spaces(src_fname)
    vert1 = [src[0]["vertno"][0:2], []]
    vert2 = [src[0]["vertno"][1:3], []]
    vert3 = [src[0]["vertno"][0:1], []]
    data1 = np.ones((2, 2))
    data2 = np.ones((2, 2))
    data3 = np.array([[0.8, 1]])
    stc_true = SourceEstimate(data1, vert1, 0, 0.002, subject="sample")
    stc_est1 = SourceEstimate(data2, vert2, 0, 0.002, subject="sample")
    stc_est2 = SourceEstimate(data3, vert3, 0, 0.002, subject="sample")

    E_unique1 = recall_score(stc_true, stc_est1, per_sample=False)
    E_unique2 = recall_score(stc_true, stc_est2, per_sample=False)
    E_per_sample1 = recall_score(stc_true, stc_est2)
    E_per_sample2 = recall_score(stc_true, stc_est2, threshold="70%")
    with pytest.raises(TypeError, match="numeric"):
        precision_score(stc_true, stc_est2, threshold=None)

    # ### Tests to add
    assert_allclose(E_unique1, 0.5)
    assert_allclose(E_unique2, 0.5)
    assert_allclose(E_per_sample1, [0.0, 0.5])
    assert_allclose(E_per_sample2, [0.5, 0.5])


@testing.requires_testing_data
def test_f1_score():
    """Test simulation metrics."""
    pytest.importorskip("sklearn")
    src = read_source_spaces(src_fname)
    vert1 = [src[0]["vertno"][0:2], []]
    vert2 = [src[0]["vertno"][1:3], []]
    vert3 = [src[0]["vertno"][0:1], []]
    data1 = np.ones((2, 2))
    data2 = np.ones((2, 2))
    data3 = np.array([[0.8, 1]])
    stc_true = SourceEstimate(data1, vert1, 0, 0.002, subject="sample")
    stc_est1 = SourceEstimate(data2, vert2, 0, 0.002, subject="sample")
    stc_est2 = SourceEstimate(data3, vert3, 0, 0.002, subject="sample")

    E_unique1 = f1_score(stc_true, stc_est1, per_sample=False)
    E_unique2 = f1_score(stc_true, stc_est2, per_sample=False)
    E_per_sample1 = f1_score(stc_true, stc_est2)
    E_per_sample2 = f1_score(stc_true, stc_est2, threshold="70%")
    assert_allclose(E_unique1, 0.5)
    assert_allclose(E_unique2, 1.0 / 1.5)
    assert_allclose(E_per_sample1, [0.0, 1.0 / 1.5])
    assert_allclose(E_per_sample2, [1.0 / 1.5, 1.0 / 1.5])


@testing.requires_testing_data
def test_roc_auc_score():
    """Test simulation metrics."""
    pytest.importorskip("sklearn")
    src = read_source_spaces(src_fname)
    vert1 = [src[0]["vertno"][0:4], []]
    vert2 = [src[0]["vertno"][0:4], []]
    data1 = np.array([[0.0, 0.0, 1, 1]]).T
    data2 = np.array([[0.1, -0.4, 0.35, 0.8]]).T
    stc_true = SourceEstimate(data1, vert1, 0, 0.002, subject="sample")
    stc_est = SourceEstimate(data2, vert2, 0, 0.002, subject="sample")

    score = roc_auc_score(stc_true, stc_est, per_sample=False)
    assert_allclose(score, 0.75)


@testing.requires_testing_data
def test_peak_position_error():
    """Test simulation metrics."""
    src = read_source_spaces(src_fname)
    vert1 = [src[0]["vertno"][0:1], []]
    vert2 = [src[0]["vertno"][0:2], []]
    data1 = np.array([[1]])
    data2 = np.array([[1, 1.0]]).T
    stc_true = SourceEstimate(data1, vert1, 0, 0.002, subject="sample")
    stc_est = SourceEstimate(data2, vert2, 0, 0.002, subject="sample")
    r_mean = 0.5 * (src[0]["rr"][vert2[0][0]] + src[0]["rr"][vert2[0][1]])
    r_true = src[0]["rr"][vert2[0][0]]
    score = peak_position_error(stc_true, stc_est, src, per_sample=False)

    assert_allclose(score, norm(r_true - r_mean))
    with pytest.raises(ValueError, match="must contain only one dipole"):
        peak_position_error(stc_est, stc_est, src)

    data2 = np.array([[0, 0.0]]).T
    stc_est = SourceEstimate(data2, vert2, 0, 0.002, subject="sample")
    score = peak_position_error(stc_true, stc_est, src, per_sample=False)
    assert_allclose(score, np.inf)


@testing.requires_testing_data
def test_spatial_deviation():
    """Test simulation metrics."""
    src = read_source_spaces(src_fname)
    vert1 = [src[0]["vertno"][0:1], []]
    vert2 = [src[0]["vertno"][0:2], []]
    data1 = np.array([[1]])
    data2 = np.array([[1, 1.0]]).T
    stc_true = SourceEstimate(data1, vert1, 0, 0.002, subject="sample")
    stc_est = SourceEstimate(data2, vert2, 0, 0.002, subject="sample")
    std = np.sqrt(
        0.5 * (0 + norm(src[0]["rr"][vert2[0][1]] - src[0]["rr"][vert2[0][0]]) ** 2)
    )
    score = spatial_deviation_error(stc_true, stc_est, src, per_sample=False)
    assert_allclose(score, std)

    data2 = np.array([[0, 0.0]]).T
    stc_est = SourceEstimate(data2, vert2, 0, 0.002, subject="sample")
    score = spatial_deviation_error(stc_true, stc_est, src, per_sample=False)
    assert_allclose(score, np.inf)
