# Authors: The MNE-Python contributors.
# License: BSD-3-Clause
# Copyright the MNE-Python contributors.

from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pytest
from numpy.testing import assert_allclose, assert_array_almost_equal, assert_equal

from mne import Epochs, find_events, pick_types
from mne._fiff.meas_info import create_info
from mne._fiff.pick import get_channel_type_constants
from mne.channels import make_dig_montage
from mne.io import read_raw_fif
from mne.io.array import RawArray
from mne.io.tests.test_raw import _test_raw_reader

# Test-data directory, resolved relative to the location of this file.
base_dir = Path(__file__).parents[2].joinpath("tests", "data")
# FIF recording used as the source of realistic data/channel layout.
fif_fname = base_dir.joinpath("test_raw.fif")


def test_long_names():
    """Check that 16-character channel names survive RawArray construction."""
    expected = ["a" * 15 + "b", "a" * 16]
    info = create_info(list(expected), 1000.0, verbose="error")
    raw = RawArray(np.zeros((2, 1000)), info)
    assert raw.ch_names == expected

    # Truncating every name to 13 characters emulates the old behavior; the
    # resulting duplicates are made unique with "-<index>" suffixes.
    truncation = {name: name[:13] for name in raw.ch_names}
    raw.rename_channels(truncation, allow_duplicates=True, verbose="error")
    assert raw.ch_names == [f"{'a' * 13}-{idx}" for idx in range(2)]

    # Many identical long names also receive unique numeric suffixes.
    n_dup = 11
    info = create_info(["a" * 16] * n_dup, 1000.0, verbose="error")
    raw = RawArray(np.zeros((n_dup, 1000)), info)
    assert raw.ch_names == [f"{'a' * 16}-{idx}" for idx in range(n_dup)]


def test_array_copy():
    """Check that ``copy`` decides whether the data array and info are reused."""
    info = create_info(1, 1000.0)
    data = np.zeros((1, 1000))

    def _check(raw, *, data_shared, info_shared):
        # "Shared" means the very same object (identity, not equality) was kept.
        assert (raw._data is data) == data_shared
        assert (raw.info is info) == info_shared

    # Default mode ('auto'): float64 data is reused while info is duplicated;
    # data that needs a dtype conversion necessarily ends up in a new array.
    _check(RawArray(data, info), data_shared=True, info_shared=False)
    _check(
        RawArray(data.astype(np.float32), info), data_shared=False, info_shared=False
    )

    # 'info': only info may be copied, so a required data conversion is an error.
    _check(RawArray(data, info, copy="info"), data_shared=True, info_shared=False)
    with pytest.raises(ValueError, match="data copying was not .* copy='info"):
        RawArray(data.astype(np.float32), info, copy="info")

    # 'data': the array is always copied, the info object is reused.
    _check(RawArray(data, info, copy="data"), data_shared=False, info_shared=True)

    # 'both': nothing is reused, whatever the input dtype.
    _check(RawArray(data, info, copy="both"), data_shared=False, info_shared=False)
    _check(
        RawArray(data.astype(np.float32), info, copy="both"),
        data_shared=False,
        info_shared=False,
    )

    # None: everything is reused, so a required data conversion is an error.
    _check(RawArray(data, info, copy=None), data_shared=True, info_shared=True)
    with pytest.raises(ValueError, match="data copying was not .* copy=None"):
        RawArray(data.astype(np.float32), info, copy=None)


@pytest.mark.slowtest
def test_array_raw():
    """Test creating raw from array.

    A RawArray is built from a channel subset of a FIF recording (with renamed
    channels and reassigned channel types), then exercised through filtering,
    plotting, epoching, complex-valued data and montage-based PSD plotting.
    All random data is drawn from a seeded generator so the test is
    reproducible.
    """
    # creating
    raw = read_raw_fif(fif_fname).crop(2, 5)
    data, times = raw[:, :]
    sfreq = raw.info["sfreq"]
    # change them, why not: drop the first 4 characters of non-stim names
    ch_names = [(ch[4:] if "STI" not in ch else ch) for ch in raw.info["ch_names"]]
    # One channel type per channel of the file, in file order.
    types = list()
    for _ in range(101):
        types.extend(("grad", "grad", "mag"))
    types.extend(["ecog", "seeg", "hbo"])  # really MEG channels
    types.extend(["stim"] * 9)
    types.extend(["dbs"])  # really an EEG channel
    types.extend(["eeg"] * 60)
    picks = np.concatenate(
        [
            pick_types(raw.info, meg=True)[::20],
            pick_types(raw.info, meg=False, stim=True),
            pick_types(raw.info, meg=False, eeg=True)[::20],
        ]
    )
    del raw
    data = data[picks]
    ch_names = np.array(ch_names)[picks].tolist()
    types = np.array(types)[picks].tolist()
    types.pop(-1)
    # wrong length
    with pytest.raises(ValueError):
        create_info(ch_names, sfreq, types)
    # bad entry
    types.append("foo")
    with pytest.raises(KeyError):
        create_info(ch_names, sfreq, types)
    types[-1] = "eog"
    # default type
    info = create_info(ch_names, sfreq)
    assert_equal(info["chs"][0]["kind"], get_channel_type_constants()["misc"]["kind"])
    # use real types
    info = create_info(ch_names, sfreq, types)
    raw2 = _test_raw_reader(
        RawArray,
        test_preloading=False,
        data=data,
        info=info,
        first_samp=2 * data.shape[1],
    )
    data2, times2 = raw2[:, :]
    assert_allclose(data, data2)
    assert_allclose(times, times2)
    assert "RawArray" in repr(raw2)
    # arguments in the wrong order
    with pytest.raises(TypeError):
        RawArray(info, data)

    # filtering: low-pass + band-pass + high-pass (and band-pass + band-stop)
    # must add back up to the unfiltered signal
    picks = pick_types(raw2.info, meg=True, misc=True, exclude="bads")[:4]
    assert_equal(len(picks), 4)
    raw_lp = raw2.copy()
    kwargs = dict(fir_design="firwin", picks=picks)
    raw_lp.filter(None, 4.0, h_trans_bandwidth=4.0, **kwargs)
    raw_hp = raw2.copy()
    raw_hp.filter(16.0, None, l_trans_bandwidth=4.0, **kwargs)
    raw_bp = raw2.copy()
    raw_bp.filter(8.0, 12.0, l_trans_bandwidth=4.0, h_trans_bandwidth=4.0, **kwargs)
    raw_bs = raw2.copy()
    raw_bs.filter(16.0, 4.0, l_trans_bandwidth=4.0, h_trans_bandwidth=4.0, **kwargs)
    data, _ = raw2[picks, :]
    lp_data, _ = raw_lp[picks, :]
    hp_data, _ = raw_hp[picks, :]
    bp_data, _ = raw_bp[picks, :]
    bs_data, _ = raw_bs[picks, :]
    sig_dec = 15
    assert_array_almost_equal(data, lp_data + bp_data + hp_data, sig_dec)
    assert_array_almost_equal(data, bp_data + bs_data, sig_dec)

    # plotting
    raw2.plot()
    raw2.compute_psd(tmax=2.0, n_fft=1024).plot(
        average=True, amplitude=False, spatial_colors=False
    )
    plt.close("all")

    # epoching
    events = find_events(raw2, stim_channel="STI 014")
    events[:, 2] = 1
    assert len(events) > 2
    epochs = Epochs(raw2, events, 1, -0.2, 0.4, preload=True)
    evoked = epochs.average()
    assert_equal(evoked.nave, len(events) - 1)

    # complex data
    rng = np.random.RandomState(0)
    data = rng.randn(1, 100) + 1j * rng.randn(1, 100)
    raw = RawArray(data, create_info(1, 1000.0, "eeg"))
    assert_allclose(raw._data, data)

    # Using a digital montage to give electrode coordinates (head frame)
    n_elec = 10
    ts_size = 10000
    sfreq = 512.0
    ch_names = [str(i) for i in range(n_elec)]
    # Draw from the seeded generator (not the global NumPy RNG) so the
    # positions and data, and hence the test, are reproducible.
    ch_pos_loc = rng.randint(60, size=(n_elec, 3)).tolist()

    data = rng.rand(n_elec, ts_size)
    montage = make_dig_montage(
        ch_pos=dict(zip(ch_names, ch_pos_loc)), coord_frame="head"
    )
    info = create_info(ch_names, sfreq, "ecog")

    raw = RawArray(data, info)
    raw.set_montage(montage)
    spectrum = raw.compute_psd()
    spectrum.plot(average=False, amplitude=False)  # looking for nonexistent layout
    spectrum.plot_topo()
    # Do not leak the figures created above into other tests.
    plt.close("all")
