# Authors: The MNE-Python contributors.
# License: BSD-3-Clause
# Copyright the MNE-Python contributors.

import numpy as np
import pytest
from numpy.testing import assert_array_almost_equal, assert_array_equal, assert_equal

from mne import (
    convert_forward_solution,
    pick_types_forward,
    read_forward_solution,
    read_label,
)
from mne.datasets import testing
from mne.label import Label
from mne.simulation import SourceSimulator, simulate_sparse_stc, simulate_stc

# Locations inside the MNE testing dataset. ``download=False`` is passed, so
# building these paths does not request a download.
data_path = testing.data_path(download=False)
fname_fwd = data_path / "MEG" / "sample" / "sample_audvis_trunc-meg-eeg-oct-6-fwd.fif"
# Base names of the ``.label`` files read by the ``_get_fwd_labels`` fixture.
label_names = ["Aud-lh", "Aud-rh", "Vis-rh"]
subjects_dir = data_path / "subjects"


@pytest.fixture(scope="module", params=[testing._pytest_param()])
def _get_fwd_labels():
    """Provide the forward solution and labels shared by the tests.

    The forward solution is read from ``fname_fwd``, converted to fixed
    orientation (``force_fixed=True, use_cps=True``) and restricted to MEG
    channels. One label is read per entry of ``label_names``.

    Returns
    -------
    fwd, labels : tuple
        The processed forward solution and the list of labels.
    """
    fwd = pick_types_forward(
        convert_forward_solution(
            read_forward_solution(fname_fwd), force_fixed=True, use_cps=True
        ),
        meg=True,
        eeg=False,
    )
    label_dir = data_path / "MEG" / "sample" / "labels"
    labels = [read_label(label_dir / f"{name}.label") for name in label_names]
    return fwd, labels


def _get_idx_label_stc(label, stc):
    """Return the row indices of ``stc.data`` covered by ``label``.

    Parameters
    ----------
    label : object with ``hemi`` (``"lh"`` or ``"rh"``) and ``vertices``
        The label to look up.
    stc : object with ``vertices`` (one sorted array per hemisphere)
        The source estimate whose rows are indexed.

    Returns
    -------
    idx : ndarray of int
        Positions of the label's vertices within the stc. Right-hemisphere
        positions are shifted by the number of left-hemisphere vertices.
    """
    hemi_idx = {"lh": 0, "rh": 1}[label.hemi]
    hemi_vertices = stc.vertices[hemi_idx]

    # Vertices present in both, then their positions within this hemisphere.
    shared = np.intersect1d(hemi_vertices, label.vertices)
    rows = np.searchsorted(hemi_vertices, shared)

    # Right-hemisphere rows come after all left-hemisphere rows.
    offset = len(stc.vertices[0]) if hemi_idx == 1 else 0
    return rows + offset


def test_simulate_stc(_get_fwd_labels):
    """Test generation of source estimate.

    Checks the simulated data with and without a value function, then a few
    degenerate inputs: fewer data rows than labels, duplicated labels, and a
    label made of a single vertex.
    """
    fwd, labels = _get_fwd_labels
    # Same geometry as the loaded labels, but label ``i`` carries value ``2 * i``.
    mylabels = [
        Label(
            vertices=label.vertices,
            pos=label.pos,
            values=2 * i * np.ones(len(label.values)),
            hemi=label.hemi,
            comment=label.comment,
        )
        for i, label in enumerate(labels)
    ]

    n_times = 10
    tmin = 0
    tstep = 1e-3

    stc_data = np.ones((len(labels), n_times))
    stc = simulate_stc(fwd["src"], mylabels, stc_data, tmin, tstep)
    assert_equal(stc.subject, "sample")

    for label in labels:
        idx = _get_idx_label_stc(label, stc)
        assert np.all(stc.data[idx] == 1.0)
        assert stc.data[idx].shape[1] == n_times

    # test with function
    def fun(x):
        return x**2

    stc = simulate_stc(fwd["src"], mylabels, stc_data, tmin, tstep, fun)

    # label ``i`` has value ``2 * i`` (0, 2, 4, ...), so with ``fun`` the
    # expected data are ``(2 * i) ** 2`` (0, 4, 16, ...)
    for i, label in enumerate(labels):
        idx = _get_idx_label_stc(label, stc)
        res = ((2.0 * i) ** 2.0) * np.ones((len(idx), n_times))
        assert_array_almost_equal(stc.data[idx], res)

    # degenerate conditions
    label_subset = mylabels[:2]
    data_subset = stc_data[:2]
    stc = simulate_stc(fwd["src"], label_subset, data_subset, tmin, tstep, fun)

    # one data row fewer than there are labels
    with pytest.raises(ValueError):
        simulate_stc(fwd["src"], label_subset, data_subset[:-1], tmin, tstep, fun)
    # every label (and its data row) passed twice
    with pytest.raises(RuntimeError):
        simulate_stc(
            fwd["src"],
            label_subset * 2,
            np.concatenate([data_subset] * 2, axis=0),
            tmin,
            tstep,
            fun,
        )

    # Single-vertex label built on the first lh vertex flagged ``inuse == 0``.
    vert = np.where(fwd["src"][0]["inuse"] == 0)[0][0]
    label_single_vert = Label(
        vertices=[vert], pos=fwd["src"][0]["rr"][vert : vert + 1, :], hemi="lh"
    )
    stc = simulate_stc(fwd["src"], [label_single_vert], stc_data[:1], tmin, tstep)
    assert_equal(len(stc.lh_vertno), 1)


def test_simulate_sparse_stc(_get_fwd_labels):
    """Test generation of sparse source estimate.

    Covers both ``location`` modes, reproducibility for identical arguments,
    and several invalid-argument cases that must raise ``ValueError``.
    """
    pytest.importorskip("nibabel")
    fwd, labels = _get_fwd_labels
    n_times = 10
    tmin = 0
    tstep = 1e-3
    times = np.arange(n_times, dtype=np.float64) * tstep + tmin

    # no non-zero values
    with pytest.raises(ValueError):
        simulate_sparse_stc(
            fwd["src"],
            len(labels),
            times,
            labels=labels,
            location="center",
            subject="sample",
            subjects_dir=subjects_dir,
        )

    # Copies of the labels with every value set to 1.
    mylabels = []
    for label in labels:
        this_label = label.copy()
        this_label.values.fill(1.0)
        mylabels.append(this_label)

    for location in ("random", "center"):
        random_state = 0 if location == "random" else None
        stc_1 = simulate_sparse_stc(
            fwd["src"],
            len(mylabels),
            times,
            labels=mylabels,
            random_state=random_state,
            location=location,
            subjects_dir=subjects_dir,
        )
        assert_equal(stc_1.subject, "sample")

        assert stc_1.data.shape[0] == len(mylabels)
        assert stc_1.data.shape[1] == n_times

        # make sure we get the same result when using the same seed
        stc_2 = simulate_sparse_stc(
            fwd["src"],
            len(mylabels),
            times,
            labels=mylabels,
            random_state=random_state,
            location=location,
            subjects_dir=subjects_dir,
        )

        assert_array_equal(stc_1.lh_vertno, stc_2.lh_vertno)
        assert_array_equal(stc_1.rh_vertno, stc_2.rh_vertno)

    # Degenerate cases
    # wrong subject
    with pytest.raises(ValueError):
        simulate_sparse_stc(
            fwd["src"],
            len(mylabels),
            times,
            labels=mylabels,
            location="center",
            subject="foo",
            subjects_dir=subjects_dir,
        )

    # no subject: the fixture is module-scoped, so the entry removed here must
    # be restored even if the assertion fails, or later tests would see the
    # modified source space.
    src_lh = fwd["src"][0]
    subject = src_lh["subject_his_id"]
    del src_lh["subject_his_id"]
    try:
        with pytest.raises(ValueError):
            simulate_sparse_stc(
                fwd["src"],
                len(mylabels),
                times,
                labels=mylabels,
                location="center",
                subjects_dir=subjects_dir,
            )
    finally:
        src_lh["subject_his_id"] = subject  # put back subject

    # bad location
    with pytest.raises(ValueError):
        simulate_sparse_stc(
            fwd["src"],
            len(mylabels),
            times,
            labels=mylabels,
            location="foo",
        )

    # more sources requested than labels given; the arguments are spelled out
    # rather than taken from the loop variables above
    err_str = "Number of labels"
    with pytest.raises(ValueError, match=err_str):
        simulate_sparse_stc(
            fwd["src"],
            len(mylabels) + 1,
            times,
            labels=mylabels,
            random_state=None,
            location="center",
            subjects_dir=subjects_dir,
        )


def test_generate_stc_single_hemi(_get_fwd_labels):
    """Test generation of source estimate, single hemi.

    Same checks as for the full label set, but the first label is dropped
    before simulating.
    """
    fwd, labels = _get_fwd_labels
    labels_single_hemi = labels[1:]  # keep only labels in one hemisphere

    # Same geometry as the kept labels, but label ``i`` carries value ``2 * i``.
    mylabels = [
        Label(
            vertices=label.vertices,
            pos=label.pos,
            values=2 * i * np.ones(len(label.values)),
            hemi=label.hemi,
            comment=label.comment,
        )
        for i, label in enumerate(labels_single_hemi)
    ]

    n_times = 10
    tmin = 0
    tstep = 1e-3

    stc_data = np.ones((len(labels_single_hemi), n_times))
    stc = simulate_stc(fwd["src"], mylabels, stc_data, tmin, tstep)

    for label in labels_single_hemi:
        idx = _get_idx_label_stc(label, stc)
        assert np.all(stc.data[idx] == 1.0)
        assert stc.data[idx].shape[1] == n_times

    # test with function
    def fun(x):
        return x**2

    stc = simulate_stc(fwd["src"], mylabels, stc_data, tmin, tstep, fun)

    # label ``i`` has value ``2 * i``, so with ``fun`` the expected data are
    # ``(2 * i) ** 2``
    for i, label in enumerate(labels_single_hemi):
        # Same index lookup as in the loop above, via the shared helper.
        idx = _get_idx_label_stc(label, stc)
        res = ((2.0 * i) ** 2.0) * np.ones((len(idx), n_times))
        assert_array_almost_equal(stc.data[idx], res)


def test_simulate_sparse_stc_single_hemi(_get_fwd_labels):
    """Test generation of sparse source estimate, single hemi."""
    fwd, all_labels = _get_fwd_labels
    # Drop the first label so the remaining ones are in one hemisphere.
    one_hemi_labels = all_labels[1:]
    n_sources = len(one_hemi_labels)

    n_times = 10
    tmin = 0
    tstep = 1e-3
    times = np.arange(n_times, dtype=np.float64) * tstep + tmin

    def simulate():
        """Run the simulation with a fixed seed."""
        return simulate_sparse_stc(
            fwd["src"],
            n_sources,
            times,
            labels=one_hemi_labels,
            random_state=0,
        )

    first = simulate()
    n_rows, n_cols = first.data.shape[0], first.data.shape[1]
    assert n_rows == n_sources
    assert n_cols == n_times

    # A second run with the same seed must pick the same vertices.
    second = simulate()
    assert_array_equal(first.lh_vertno, second.lh_vertno)
    assert_array_equal(first.rh_vertno, second.rh_vertno)


@testing.requires_testing_data
def test_simulate_stc_labels_overlap(_get_fwd_labels):
    """Test generation of source estimate when labels overlap."""
    fwd, labels = _get_fwd_labels
    src = fwd["src"]

    # Label ``k`` keeps its geometry and gets the constant value ``2 * k``.
    overlapping = [
        Label(
            vertices=lab.vertices,
            pos=lab.pos,
            values=2 * k * np.ones(len(lab.values)),
            hemi=lab.hemi,
            comment=lab.comment,
        )
        for k, lab in enumerate(labels)
    ]
    # Append the last label a second time so that two labels coincide.
    overlapping.append(overlapping[-1])

    n_times = 10
    tmin, tstep = 0, 1e-3
    stc_data = np.ones((len(overlapping), n_times))

    # Overlap must be rejected when it is not allowed ...
    with pytest.raises(RuntimeError, match="must be non-overlapping"):
        simulate_stc(src, overlapping, stc_data, tmin, tstep, allow_overlap=False)

    # ... and accepted when it is.
    stc = simulate_stc(src, overlapping, stc_data, tmin, tstep, allow_overlap=True)
    assert_equal(stc.subject, "sample")
    assert stc.data.shape[1] == n_times
    # The duplicated label is expected to produce some entries equal to 2.
    assert 2 in stc.data


def test_source_simulator(_get_fwd_labels):
    """Test Source Simulator.

    Builds three labels (two overlapping ones on "lh", one on "rh"), attaches
    waveforms and events to them, and checks the simulated data, iteration
    over the simulator, the ``duration`` / ``start_sample`` / ``stop_sample``
    arguments, argument validation, and ``first_samp`` handling.
    """
    fwd, _ = _get_fwd_labels

    src = fwd["src"]
    hemi_to_ind = {"lh": 0, "rh": 1}
    tstep = 1.0 / 6.0

    # Labels 0 and 1 are both "lh" and share vertices 500-999; label 2 is "rh".
    label_vertices = [[], [], []]
    label_vertices[0] = np.arange(1000)
    label_vertices[1] = np.arange(500, 1500)
    label_vertices[2] = np.arange(1000)

    hemis = ["lh", "lh", "rh"]

    # For each label, also record which of its vertices are in the source space.
    mylabels = []
    src_vertices = []
    for i, vert in enumerate(label_vertices):
        new_label = Label(vertices=vert, hemi=hemis[i])
        mylabels.append(new_label)
        src_vertices.append(
            np.intersect1d(src[hemi_to_ind[hemis[i]]]["vertno"], new_label.vertices)
        )

    # One waveform specification per label, each in a different container form.
    wfs = [[], [], []]

    wfs[0] = np.array([0, 1.0, 0])  # 1d array
    wfs[1] = [np.array([0, 1.0, 0]), np.array([0, 1.5, 0])]  # list
    wfs[2] = np.array([[1, 1, 1.0]])  # 2d array

    # Two events per label; the first column is the sample index.
    events = [[], [], []]
    events[0] = np.array([[0, 0, 1], [3, 0, 1]])
    events[1] = np.array([[0, 0, 1], [3, 0, 1]])
    events[2] = np.array([[0, 0, 1], [2, 0, 1]])

    verts_lh = np.intersect1d(range(1500), src[0]["vertno"])
    verts_rh = np.intersect1d(range(1000), src[1]["vertno"])
    # Vertex counts: only in label 0, only in label 1, and in both.
    diff_01 = len(np.setdiff1d(src_vertices[0], src_vertices[1]))
    diff_10 = len(np.setdiff1d(src_vertices[1], src_vertices[0]))
    inter_10 = len(np.intersect1d(src_vertices[1], src_vertices[0]))

    # Expected "lh" data, laid out in ascending vertex order: vertices only in
    # label 0, then the overlap, then vertices only in label 1.
    output_data_lh = np.zeros([len(verts_lh), 6])
    # Label 0 only: waveform [0, 1, 0] at samples 0 and 3.
    tmp = np.array([0, 1.0, 0, 0, 1, 0])
    output_data_lh[:diff_01, :] = np.tile(tmp, (diff_01, 1))

    # Overlap: the expected values are the sum of both labels' contributions.
    tmp = np.array([0, 2, 0, 0, 2.5, 0])
    output_data_lh[diff_01 : diff_01 + inter_10, :] = np.tile(tmp, (inter_10, 1))
    # Label 1 only: [0, 1, 0] for the first event, [0, 1.5, 0] for the second.
    tmp = np.array([0, 1, 0, 0, 1.5, 0])
    output_data_lh[diff_01 + inter_10 :, :] = np.tile(tmp, (diff_10, 1))

    # Label 2: [1, 1, 1] at samples 0 and 2, which add up at sample 2.
    data_rh_wf = np.array([1.0, 1, 2, 1, 1, 0])
    output_data_rh = np.tile(data_rh_wf, (len(src_vertices[2]), 1))
    output_data = np.vstack([output_data_lh, output_data_rh])

    ss = SourceSimulator(src, tstep)
    for i in range(3):
        ss.add_data(mylabels[i], wfs[i], events[i])

    stc = ss.get_stc()
    stim_channel = ss.get_stim_channel()

    # Make some size checks.
    # With tstep = 1/6, six samples span a duration of 1.0.
    assert ss.duration == 1.0
    assert ss.n_times == 6
    assert ss.last_samp == 5

    assert len(stim_channel) == stc.data.shape[1]
    assert np.all(stc.vertices[0] == verts_lh)
    assert np.all(stc.vertices[1] == verts_rh)
    assert_array_almost_equal(stc.lh_data, output_data_lh)
    assert_array_almost_equal(stc.rh_data, output_data_rh)
    assert_array_almost_equal(stc.data, output_data)

    # Iterating the simulator yields (stc, stim) pairs; exactly one is expected.
    # NOTE: this loop rebinds ``stc``, and the assertions further down use the
    # object yielded here rather than the result of ``ss.get_stc()`` above.
    counter = 0
    for stc, stim in ss:
        assert stc.data.shape[1] == 6
        counter += 1
    assert counter == 1

    # Check validity of setting duration and start/stop parameters.
    half_ss = SourceSimulator(src, tstep, duration=0.5)
    for i in range(3):
        half_ss.add_data(mylabels[i], wfs[i], events[i])
    # Events given as floats must be rejected.
    with pytest.raises(TypeError, match="array of integers"):
        half_ss.add_data(mylabels[0], wfs[0], events[0].astype(float))
    half_stc = half_ss.get_stc()
    # A duration of 0.5 is expected to give the first three of the six samples.
    assert_array_almost_equal(stc.data[:, :3], half_stc.data)

    # Samples 1 to 4 give four time points, so ``stop_sample`` is inclusive.
    part_stc = ss.get_stc(start_sample=1, stop_sample=4)
    assert part_stc.shape == (24, 4)
    assert part_stc.times[0] == tstep

    # Check validity of other arguments.
    with pytest.raises(ValueError, match="start_sample must be"):
        ss.get_stc(2, 0)
    # A simulator to which no data has been added.
    ss = SourceSimulator(src)
    with pytest.raises(ValueError, match="No simulation parameters"):
        ss.get_stc()
    with pytest.raises(TypeError, match="must be an instance of Label"):
        ss.add_data(1, wfs, events)
    with pytest.raises(ValueError, match="Number of waveforms and events should match"):
        ss.add_data(mylabels[0], wfs[:2], events)
    # The constructor raises, so ``ss`` is not rebound by this statement.
    with pytest.raises(ValueError, match="duration must be None or"):
        ss = SourceSimulator(src, tstep, tstep / 2)

    # Verify first_samp functionality.
    ss = SourceSimulator(src, tstep)
    offset = 50
    # The sample column of ``events`` is shifted in place, so the second loop
    # below reuses the already-offset events.
    for i in range(3):  # events are offset, but first_samp = 0
        events[i][:, 0] += offset
        ss.add_data(mylabels[i], wfs[i], events[i])
    offset_stc = ss.get_stc()
    assert ss.n_times == 56
    assert ss.first_samp == 0
    assert offset_stc.data.shape == (stc.data.shape[0], stc.data.shape[1] + offset)
    ss = SourceSimulator(src, tstep, first_samp=offset)
    for i in range(3):  # events still offset, but first_samp > 0
        ss.add_data(mylabels[i], wfs[i], events[i])
    offset_stc = ss.get_stc()
    assert ss.n_times == 6
    assert ss.first_samp == offset
    assert ss.last_samp == offset + 5
    assert offset_stc.data.shape == stc.data.shape

    # Verify that the chunks have the correct length.
    # A duration of ten samples iterated in chunks of six should give 6 + 4.
    source_simulator = SourceSimulator(src, tstep=tstep, duration=10 * tstep)
    source_simulator.add_data(mylabels[0], np.array([1, 1, 1]), [[0, 0, 0]])

    source_simulator._chk_duration = 6  # Quick hack to get short chunks.
    stcs = [stc for stc, _ in source_simulator]
    assert len(stcs) == 2
    assert stcs[0].data.shape[1] == 6
    assert stcs[1].data.shape[1] == 4
