# Authors: The MNE-Python contributors.
# License: BSD-3-Clause
# Copyright the MNE-Python contributors.

import itertools
from contextlib import nullcontext
from pathlib import Path

import numpy as np
import pytest
from numpy.testing import assert_allclose, assert_array_equal, assert_equal

from mne import (
    Epochs,
    Projection,
    add_reference_channels,
    create_info,
    find_events,
    make_forward_solution,
    make_sphere_model,
    pick_channels,
    pick_channels_forward,
    pick_types,
    read_events,
    read_evokeds,
    set_bipolar_reference,
    set_eeg_reference,
    setup_volume_source_space,
)
from mne._fiff.constants import FIFF
from mne._fiff.proj import _has_eeg_average_ref_proj
from mne._fiff.reference import _apply_reference
from mne.datasets import testing
from mne.epochs import BaseEpochs, make_fixed_length_epochs
from mne.io import RawArray, read_raw_fif
from mne.utils import _record_warnings, catch_logging

# Small test files in the ``io/tests/data`` directory under the grandparent of
# this file's directory.
base_dir = Path(__file__).parents[2].joinpath("io", "tests", "data")
raw_fname = base_dir.joinpath("test_raw.fif")
# Files from the MNE testing dataset (``download=False``: do not fetch it here).
data_dir = testing.data_path(download=False).joinpath("MEG", "sample")
fif_fname = data_dir.joinpath("sample_audvis_trunc_raw.fif")
eve_fname = data_dir.joinpath("sample_audvis_trunc_raw-eve.fif")
ave_fname = data_dir.joinpath("sample_audvis-ave.fif")


def _test_reference(raw, reref, ref_data, ref_from):
    """Check that ``reref`` is ``raw`` re-referenced to ``ref_from``.

    Parameters
    ----------
    raw : Raw | Epochs | Evoked
        The instance before re-referencing (read through ``._data``).
    reref : Raw | Epochs | Evoked
        The instance after re-referencing.
    ref_data : ndarray | None
        The reference signal that was subtracted. When None (as returned for
        projection-based references), neither the reference itself nor the
        undo step is checked.
    ref_from : list of str
        Names of the channels the reference was computed from.
    """
    info = raw.info
    # EEG channels are re-referenced; MEG/EOG/stim channels must not change.
    eeg_idx = pick_types(info, meg=False, eeg=True, exclude="bads")
    other_idx = pick_types(
        info, meg=True, eeg=False, eog=True, stim=True, exclude="bads"
    )
    # Indices of the reference channels
    ref_idx = [raw.ch_names.index(name) for name in ref_from]

    before = raw._data
    after = reref._data
    have_ref = ref_data is not None

    # The reference must be the plain mean over the reference channels
    if have_ref:
        assert_array_equal(ref_data, before[..., ref_idx, :].mean(-2))

    # Non-EEG channels must come through untouched
    assert_allclose(
        before[..., other_idx, :], after[..., other_idx, :], 1e-6, atol=1e-15
    )

    if not have_ref:
        return

    # Adding the reference back must recover the original EEG data; for epochs
    # the per-epoch reference is broadcast over the channel axis.
    if isinstance(raw, BaseEpochs):
        restore = ref_data[:, np.newaxis, :]
    else:
        restore = ref_data
    assert_allclose(
        before[..., eeg_idx, :], after[..., eeg_idx, :] + restore, 1e-6, atol=1e-15
    )


@testing.requires_testing_data
def test_apply_reference():
    """Test the private ``_apply_reference`` helper on Raw, Epochs and Evoked."""
    ref_pair = ["EEG 001", "EEG 002"]
    raw = read_raw_fif(fif_fname, preload=True)

    # Re-reference a copy of the original data
    reref, ref_data = _apply_reference(raw.copy(), ref_from=ref_pair)
    assert reref.info["custom_ref_applied"]
    _test_reference(raw, reref, ref_data, ref_pair)
    # The average-reference (CAR) projector must have been dropped
    assert not _has_eeg_average_ref_proj(reref.info)

    # Without a copy, the very same instance is modified and returned
    reref, _ = _apply_reference(raw, ref_pair)
    assert reref is raw

    # An empty reference leaves the data unchanged
    reref, _ = _apply_reference(raw.copy(), [])
    assert_array_equal(raw._data, reref._data)

    # Re-reference EEG-only Epochs and their Evoked average
    raw = read_raw_fif(fif_fname, preload=False)
    epochs = Epochs(
        raw,
        events=read_events(eve_fname),
        event_id=1,
        tmin=-0.2,
        tmax=0.5,
        picks=pick_types(raw.info, meg=False, eeg=True),
        preload=True,
    )
    for inst in (epochs, epochs.average()):
        reref, ref_data = _apply_reference(inst.copy(), ref_from=ref_pair)
        assert reref.info["custom_ref_applied"]
        _test_reference(inst, reref, ref_data, ref_pair)

    # Re-referencing requires preloaded data
    with pytest.raises(RuntimeError):
        _apply_reference(read_raw_fif(fif_fname, preload=False), ["EEG 001"])

    # An inactive SSP projector involving the reference channels blocks
    # re-referencing ...
    raw = read_raw_fif(fif_fname, preload=True)
    inactive_proj = Projection(
        active=False,
        data=dict(
            col_names=["EEG 001", "EEG 002"],
            row_names=None,
            data=np.array([[1, 1]]),
            ncol=2,
            nrow=1,
        ),
        desc="test",
        kind=1,
    )
    raw.add_proj(inactive_proj)
    with pytest.raises(RuntimeError, match="Inactive signal space"):
        _apply_reference(raw, ["EEG 001"])
    # ... but a reference not involving those channels is fine
    _apply_reference(raw, ["EEG 003"], ["EEG 004"])

    # CSD-transformed data cannot be re-referenced
    with raw.info._unlock():
        raw.info["custom_ref_applied"] = FIFF.FIFFV_MNE_CUSTOM_REF_CSD
    with pytest.raises(RuntimeError, match="Cannot set.* type 'CSD'"):
        raw.set_eeg_reference()


@testing.requires_testing_data
def test_set_eeg_reference():
    """Test rereferencing EEG data with ``set_eeg_reference``."""
    raw = read_raw_fif(fif_fname, preload=True)
    with raw.info._unlock():
        raw.info["projs"] = []

    # Test setting an average reference projection
    assert not _has_eeg_average_ref_proj(raw.info)
    reref, ref_data = set_eeg_reference(raw, projection=True)
    assert _has_eeg_average_ref_proj(reref.info)
    assert not reref.info["projs"][0]["active"]
    assert ref_data is None
    reref.apply_proj()
    eeg_chans = [raw.ch_names[ch] for ch in pick_types(raw.info, meg=False, eeg=True)]
    _test_reference(
        raw, reref, ref_data, [ch for ch in eeg_chans if ch not in raw.info["bads"]]
    )

    # Test setting an average reference when one was already present
    with pytest.warns(RuntimeWarning, match="untouched"):
        reref, ref_data = set_eeg_reference(raw, copy=False, projection=True)
    assert ref_data is None

    # Test setting an average reference on non-preloaded data
    raw_nopreload = read_raw_fif(fif_fname, preload=False)
    with raw_nopreload.info._unlock():
        raw_nopreload.info["projs"] = []
    reref, ref_data = set_eeg_reference(raw_nopreload, projection=True)
    assert _has_eeg_average_ref_proj(reref.info)
    assert not reref.info["projs"][0]["active"]

    # Rereference raw data by creating a copy of original data
    reref, ref_data = set_eeg_reference(raw, ["EEG 001", "EEG 002"], copy=True)
    assert reref.info["custom_ref_applied"]
    _test_reference(raw, reref, ref_data, ["EEG 001", "EEG 002"])

    # Test that data is modified in place when copy=False
    reref, ref_data = set_eeg_reference(raw, ["EEG 001", "EEG 002"], copy=False)
    assert raw is reref

    # Test moving from custom to average reference
    reref, ref_data = set_eeg_reference(raw, ["EEG 001", "EEG 002"])
    reref, _ = set_eeg_reference(reref, projection=True)
    assert _has_eeg_average_ref_proj(reref.info)
    assert not reref.info["custom_ref_applied"]

    # When creating an average reference fails, make sure the
    # custom_ref_applied flag remains untouched.
    reref = raw.copy()
    with reref.info._unlock():
        reref.info["custom_ref_applied"] = FIFF.FIFFV_MNE_CUSTOM_REF_ON
    reref.pick(picks="meg")  # Cause making average ref fail
    # picking away all EEG channels should have turned it off already
    assert reref.info["custom_ref_applied"] == FIFF.FIFFV_MNE_CUSTOM_REF_OFF
    with pytest.raises(ValueError, match="found to rereference"):
        set_eeg_reference(reref, projection=True)
    # ... and the failed call must not have changed it
    assert reref.info["custom_ref_applied"] == FIFF.FIFFV_MNE_CUSTOM_REF_OFF

    # Test moving from average to custom reference
    reref, ref_data = set_eeg_reference(raw, projection=True)
    reref, _ = set_eeg_reference(reref, ["EEG 001", "EEG 002"])
    assert not _has_eeg_average_ref_proj(reref.info)
    assert len(reref.info["projs"]) == 0
    assert reref.info["custom_ref_applied"] == FIFF.FIFFV_MNE_CUSTOM_REF_ON

    # Test that disabling the reference (ref_channels=[]) does not change the
    # data and removes average reference projectors
    assert _has_eeg_average_ref_proj(raw.info)
    reref, _ = set_eeg_reference(raw, [])
    assert_array_equal(raw._data, reref._data)
    assert not _has_eeg_average_ref_proj(reref.info)
    # the original instance (copy=True by default) keeps its projector
    assert _has_eeg_average_ref_proj(raw.info)

    # Test that average reference gives identical results when calculated
    # via SSP projection (projection=True) or directly (projection=False)
    with raw.info._unlock():
        raw.info["projs"] = []
    reref_1, _ = set_eeg_reference(raw.copy(), projection=True)
    reref_1.apply_proj()
    reref_2, _ = set_eeg_reference(raw.copy(), projection=False)
    assert_allclose(reref_1._data, reref_2._data, rtol=1e-6, atol=1e-15)

    # Test average reference without projection
    reref, ref_data = set_eeg_reference(
        raw.copy(), ref_channels="average", projection=False
    )
    _test_reference(raw, reref, ref_data, eeg_chans)

    # projection=True is only allowed together with the average reference
    with pytest.raises(ValueError, match='supported for ref_channels="averag'):
        set_eeg_reference(raw, [], True, True)
    with pytest.raises(ValueError, match='supported for ref_channels="averag'):
        set_eeg_reference(raw, ["EEG 001"], True, True)
        set_eeg_reference(raw, ["EEG 001"], True, True)


@pytest.mark.parametrize(
    "ch_type, msg",
    [
        ("auto", ("ECoG",)),
        ("ecog", ("ECoG",)),
        ("dbs", ("DBS",)),
        (["ecog", "dbs"], ("ECoG", "DBS")),
    ],
)
@pytest.mark.parametrize("projection", [False, True])
def test_set_eeg_reference_ch_type(ch_type, msg, projection):
    """Test setting the reference for ECoG and/or DBS channels.

    Regression test for gh-6454; DBS support added in gh-8739.
    """
    names = ["ECOG01", "ECOG02", "DBS01", "DBS02", "MISC"]
    types = ["ecog", "ecog", "dbs", "dbs", "misc"]
    data = np.random.RandomState(0).randn(5, 1000)
    raw = RawArray(data, create_info(names, 1000.0, types))

    # With "auto", the ECoG channels are the ones used for the reference
    if ch_type != "auto":
        ref_ch = raw.copy().pick(picks=ch_type).ch_names
    else:
        ref_ch = names[:2]

    with catch_logging() as log:
        reref, ref_data = set_eeg_reference(
            raw.copy(), ch_type=ch_type, projection=projection, verbose=True
        )

    if not projection:
        assert f"Applying a custom {msg}" in log.getvalue()
        assert reref.info["custom_ref_applied"]  # gh-7350
    _test_reference(raw, reref, ref_data, ref_ch)

    # Asking for EEG on data that has none must fail
    match = "no EEG data found" if projection else "No channels supplied"
    with pytest.raises(ValueError, match=match):
        set_eeg_reference(raw, ch_type="eeg", projection=projection)

    # gh-8739: no channel type that can be re-referenced at all
    raw_mag = RawArray(data, create_info(5, 1000.0, ["mag"] * 4 + ["misc"]))
    with pytest.raises(
        ValueError, match="No EEG, ECoG, sEEG or DBS channels found to rereference."
    ):
        set_eeg_reference(raw_mag, ch_type="auto", projection=projection)


@testing.requires_testing_data
def test_set_eeg_reference_rest():
    """Test setting a REST reference.

    Checks that bad channels are left untouched, that the result sits between
    the original and the average-referenced data, and that it matches a
    FieldTrip REST computation on the sample evoked data.
    """
    # One second of EEG-only data, loaded into memory
    raw = read_raw_fif(fif_fname).crop(0, 1).pick(picks="eeg").load_data()
    raw.info["bads"] = ["EEG 057"]  # should be excluded
    # `same` indexes the bad channel (must stay unchanged); `picks` is the rest
    same = [raw.ch_names.index(raw.info["bads"][0])]
    picks = np.setdiff1d(np.arange(len(raw.ch_names)), same)
    trans = None
    # Coarse spherical forward model built from this recording's info
    sphere = make_sphere_model("auto", "auto", raw.info)
    src = setup_volume_source_space(pos=20.0, sphere=sphere, exclude=30.0)
    assert src[0]["nuse"] == 223  # low but fast
    fwd = make_forward_solution(raw.info, trans, src, sphere)
    orig_data = raw.get_data()
    avg_data = raw.copy().set_eeg_reference("average").get_data()
    assert_array_equal(avg_data[same], orig_data[same])  # not processed
    # NOTE: re-references `raw` in place; later lines depend on this
    raw.set_eeg_reference("REST", forward=fwd)
    rest_data = raw.get_data()
    assert_array_equal(rest_data[same], orig_data[same])
    # should be more similar to an avg ref than nose ref
    orig_corr = np.corrcoef(rest_data[picks].ravel(), orig_data[picks].ravel())[0, 1]
    avg_corr = np.corrcoef(rest_data[picks].ravel(), avg_data[picks].ravel())[0, 1]
    assert -0.6 < orig_corr < -0.5
    assert 0.1 < avg_corr < 0.2
    # and applying an avg ref after should work
    avg_after = raw.set_eeg_reference("average").get_data()
    assert_allclose(avg_after, avg_data, atol=1e-12)
    # REST needs a forward solution ...
    with pytest.raises(TypeError, match='forward when ref_channels="REST"'):
        raw.set_eeg_reference("REST")
    # ... that covers every channel of the instance
    fwd_bad = pick_channels_forward(fwd, raw.ch_names[:-1])
    with pytest.raises(ValueError, match="Missing channels"):
        raw.set_eeg_reference("REST", forward=fwd_bad)
    # compare to FieldTrip
    evoked = read_evokeds(ave_fname, baseline=(None, 0))[0]
    evoked.info["bads"] = []
    evoked.pick(picks="eeg")
    assert len(evoked.ch_names) == 60
    # Data obtained from FieldTrip with something like (after evoked.save'ing
    # then scipy.io.savemat'ing fwd['sol']['data']):
    # dat = ft_read_data('ft-ave.fif');
    # load('leadfield.mat', 'G');
    # dat_ref = ft_preproc_rereference(dat, 'all', 'rest', true, G);
    # sprintf('%g ', dat_ref(:, 171));
    data_array = "-3.3265e-05 -3.2419e-05 -3.18758e-05 -3.24079e-05 -3.39801e-05 -3.40573e-05 -3.24163e-05 -3.26896e-05 -3.33814e-05 -3.54734e-05 -3.51289e-05 -3.53229e-05 -3.51532e-05 -3.53149e-05 -3.4505e-05 -3.03462e-05 -2.81848e-05 -3.08895e-05 -3.27158e-05 -3.4605e-05 -3.47728e-05 -3.2459e-05 -3.06552e-05 -2.53255e-05 -2.69671e-05 -2.83425e-05 -3.12836e-05 -3.30965e-05 -3.34099e-05 -3.32766e-05 -3.32256e-05 -3.36385e-05 -3.20796e-05 -2.7108e-05 -2.47054e-05 -2.49589e-05 -2.7382e-05 -3.09774e-05 -3.12003e-05 -3.1246e-05 -3.07572e-05 -2.64942e-05 -2.25505e-05 -2.67194e-05 -2.86e-05 -2.94903e-05 -2.96249e-05 -2.92653e-05 -2.86472e-05 -2.81016e-05 -2.69737e-05 -2.48076e-05 -3.00473e-05 -2.73404e-05 -2.60153e-05 -2.41608e-05 -2.61937e-05 -2.5539e-05 -2.47104e-05 -2.35194e-05"  # noqa: E501
    want = np.array(data_array.split(" "), float)
    norm = np.linalg.norm(want)
    # MATLAB column 171 (1-based) corresponds to Python index 170
    idx = np.argmin(np.abs(evoked.times - 0.083))
    assert idx == 170
    old = evoked.data[:, idx].ravel()
    # 1 - relative error vs. the FieldTrip result: tiny before re-referencing
    exp_var = 1 - np.linalg.norm(want - old) / norm
    assert 0.006 < exp_var < 0.008
    evoked.set_eeg_reference("REST", forward=fwd)
    exp_var_old = 1 - np.linalg.norm(evoked.data[:, idx] - old) / norm
    assert 0.005 < exp_var_old <= 0.009
    # ... and close to 1 after applying REST
    exp_var = 1 - np.linalg.norm(evoked.data[:, idx] - want) / norm
    assert 0.995 < exp_var <= 1


@testing.requires_testing_data
@pytest.mark.parametrize("inst_type", ["raw", "epochs"])
@pytest.mark.parametrize(
    "ref_channels, expectation",
    [
        (
            {2: "EEG 001"},
            pytest.raises(
                TypeError,
                match="Keys in the ref_channels dict must be strings. "
                "Your dict has keys of type int.",
            ),
        ),
        (
            {"EEG 001": (1, 2)},
            pytest.raises(
                TypeError,
                match="Values in the ref_channels dict must be strings. "
                "Your dict has values of type int.",
            ),
        ),
        (
            {"EEG 001": [1, 2]},
            pytest.raises(
                TypeError,
                match="Values in the ref_channels dict must be strings. "
                "Your dict has values of type int.",
            ),
        ),
        (
            {"EEG 999": "EEG 001"},
            pytest.raises(
                ValueError,
                match=r"ref_channels dict contains invalid key\(s\) \(EEG 999\) "
                "that are not names of channels in the instance.",
            ),
        ),
        (
            {"EEG 001": "EEG 999"},
            pytest.raises(
                ValueError,
                match=r"ref_channels dict contains invalid value\(s\) \(EEG 999\) "
                "that are not names of channels in the instance.",
            ),
        ),
        (
            {"EEG 001": "EEG 057"},
            pytest.warns(
                RuntimeWarning,
                match=r"ref_channels dict contains value\(s\) \(EEG 057\) "
                "that are marked as bad channels.",
            ),
        ),
        (
            {"EEG 001": "STI 001"},
            pytest.warns(
                RuntimeWarning,
                match=(
                    r"Channel EEG 001 \(eeg\) is referenced to channel "
                    r"STI 001 which is a different channel type \(stim\)."
                ),
            ),
        ),
        (
            {"EEG 001": "EEG 001"},
            pytest.warns(
                RuntimeWarning,
                match=(
                    "Channel EEG 001 is self-referenced, "
                    "which will nullify the channel."
                ),
            ),
        ),
        (
            {"EEG 001": "EEG 002", "EEG 002": "EEG 003", "EEG 003": "EEG 005"},
            nullcontext(),
        ),
        (
            {
                "EEG 001": ["EEG 002", "EEG 003"],
                "EEG 002": "EEG 003",
                "EEG 003": "EEG 005",
            },
            nullcontext(),
        ),
    ],
)
def test_set_eeg_reference_dict(ref_channels, inst_type, expectation):
    """Test setting dict-based reference.

    ``ref_channels`` maps each channel to re-reference onto one channel name
    or a list of names (whose mean is used as the reference). ``expectation``
    is the context manager the call must satisfy: an error, a warning, or
    ``nullcontext()`` for valid input, in which case the data are checked.
    """
    if inst_type == "raw":
        inst = read_raw_fif(fif_fname).crop(0, 1).pick(picks=["eeg", "stim"])
    # Test re-referencing Epochs object
    elif inst_type == "epochs":
        raw = read_raw_fif(fif_fname, preload=False)
        events = read_events(eve_fname)
        inst = Epochs(
            raw,
            events=events,
            event_id=1,
            tmin=-0.2,
            tmax=0.5,
            preload=False,
        )
    # Re-referencing refuses to run on data that has not been loaded
    with pytest.raises(
        RuntimeError,
        match="By default, MNE does not load data.*Applying a reference requires.*",
    ):
        inst.set_eeg_reference(ref_channels=ref_channels)
    inst.load_data()
    inst.info["bads"] = ["EEG 057"]
    with expectation:
        reref, _ = set_eeg_reference(inst.copy(), ref_channels, copy=False)

    if not isinstance(expectation, nullcontext):
        return  # error/warning cases are fully checked by ``expectation``

    # Check that the custom_ref_applied is set correctly:
    assert reref.info["custom_ref_applied"] == FIFF.FIFFV_MNE_CUSTOM_REF_ON

    # Original (un-referenced) data
    _data = inst._data

    # Channels that were and weren't re-referenced (dict membership tests
    # the keys directly, no per-channel list rebuild needed)
    ch_raw = pick_channels(
        inst.ch_names, [ch for ch in inst.ch_names if ch not in ref_channels]
    )
    ch_reref = pick_channels(inst.ch_names, list(ref_channels), ordered=True)

    # Check that the non re-referenced channels are untouched:
    assert_allclose(
        _data[..., ch_raw, :], reref._data[..., ch_raw, :], 1e-6, atol=1e-15
    )

    # Reference signal for each key, computed from the ORIGINAL data, in the
    # same order as ``ch_reref``
    ref_data = []
    for val in ref_channels.values():
        if isinstance(val, str):
            val = [val]  # pick_channels expects a list
        ref_data.append(
            _data[..., pick_channels(inst.ch_names, val, ordered=True), :].mean(
                -2, keepdims=True
            )
        )
    # Stack along the channel axis, which is second-to-last for both Raw
    # (n_channels, n_times) and Epochs (n_epochs, n_channels, n_times)
    ref_data = np.concatenate(ref_data, axis=-2)
    assert_allclose(
        _data[..., ch_reref, :],
        reref._data[..., ch_reref, :] + ref_data,
        1e-6,
        atol=1e-15,
    )


@testing.requires_testing_data
@pytest.mark.parametrize("inst_type", ("raw", "epochs", "evoked"))
def test_set_bipolar_reference(inst_type):
    """Test bipolar referencing on Raw, Epochs and Evoked instances."""
    raw = read_raw_fif(fif_fname, preload=True)
    raw.apply_proj()

    if inst_type == "raw":
        inst = raw
    else:  # "epochs" or "evoked"
        events = find_events(raw, stim_channel="STI 014")
        inst = Epochs(raw, events, tmin=-0.3, tmax=0.7, preload=True)
        if inst_type == "evoked":
            inst = inst.average()
    del raw

    # ch_info may only contain valid channel-info keys
    ch_info = {"kind": FIFF.FIFFV_EOG_CH, "extra": "some extra value"}
    with pytest.raises(KeyError, match="key errantly present"):
        set_bipolar_reference(inst, "EEG 001", "EEG 002", "bipolar", ch_info)
    del ch_info["extra"]
    reref = set_bipolar_reference(inst, "EEG 001", "EEG 002", "bipolar", ch_info)
    assert reref.info["custom_ref_applied"]

    # The virtual channel must equal anode minus cathode
    pair = inst.copy().pick(["EEG 001", "EEG 002"])._data
    expected = pair[..., 0, :] - pair[..., 1, :]
    actual = reref.copy().pick(["bipolar"])._data[..., 0, :]
    assert_allclose(expected, actual)

    # Anode and cathode are replaced by the single virtual channel
    for name in ("EEG 001", "EEG 002"):
        assert name not in reref.ch_names
    assert "bipolar" in reref.ch_names

    # Channel info is inherited from the anode, except the coil type, the
    # kind passed through ch_info, and the channel name
    bp_info = reref.info["chs"][reref.ch_names.index("bipolar")]
    an_info = inst.info["chs"][inst.ch_names.index("EEG 001")]
    overrides = {
        "coil_type": FIFF.FIFFV_COIL_EEG_BIPOLAR,
        "kind": FIFF.FIFFV_EOG_CH,
    }
    for key, value in bp_info.items():
        if key == "ch_name":
            continue
        if key in overrides:
            assert value == overrides[key], key
        else:
            assert_equal(value, an_info[key], err_msg=key)

    # Minimalist call: the name defaults to "<anode>-<cathode>"
    reref = set_bipolar_reference(inst, "EEG 001", "EEG 002")
    assert "EEG 001-EEG 002" in reref.ch_names

    # The same anode may appear in several pairs
    reref = set_bipolar_reference(
        inst, ["EEG 001", "EEG 001", "EEG 002"], ["EEG 002", "EEG 003", "EEG 003"]
    )
    for name in ("EEG 001-EEG 002", "EEG 001-EEG 003"):
        assert name in reref.ch_names

    # Several references at once
    reref = set_bipolar_reference(
        inst,
        ["EEG 001", "EEG 003"],
        ["EEG 002", "EEG 004"],
        ["bipolar1", "bipolar2"],
        [{"kind": FIFF.FIFFV_EOG_CH}, {"kind": FIFF.FIFFV_EOG_CH}],
    )
    quad = inst.copy().pick(["EEG 001", "EEG 002", "EEG 003", "EEG 004"])._data
    expected = np.concatenate(
        [
            quad[..., 0:1, :] - quad[..., 1:2, :],
            quad[..., 2:3, :] - quad[..., 3:4, :],
        ],
        axis=-2,
    )
    actual = reref.copy().pick(["bipolar1", "bipolar2"])._data
    assert_allclose(expected, actual)

    # A bipolar reference that doesn't involve EEG channels must not set the
    # custom_ref_applied flag
    reref = set_bipolar_reference(
        inst,
        "MEG 0111",
        "MEG 0112",
        ch_info={"kind": FIFF.FIFFV_MEG_CH},
        verbose="error",
    )
    assert not reref.info["custom_ref_applied"]
    assert "MEG 0111-MEG 0112" in reref.ch_names

    # A battery of invalid inputs, each of which must raise ValueError
    invalid_calls = [
        (("EEG 001", ["EEG 002", "EEG 003"], "bipolar"), {}),
        ((["EEG 001", "EEG 002"], "EEG 003", "bipolar"), {}),
        (("EEG 001", "EEG 002", ["bipolar1", "bipolar2"]), {}),
        (
            ("EEG 001", "EEG 002", "bipolar"),
            {"ch_info": [{"foo": "bar"}, {"foo": "bar"}]},
        ),
        (("EEG 001", "EEG 002"), {"ch_name": "EEG 003"}),
    ]
    for args, kwargs in invalid_calls:
        with pytest.raises(ValueError):
            set_bipolar_reference(inst, *args, **kwargs)

    # A bad anode or cathode raises an error with on_bad="raise" ...
    for bad in ("EEG 001", "EEG 002"):
        inst.info["bads"] = [bad]
        with pytest.raises(ValueError):
            set_bipolar_reference(inst, "EEG 001", "EEG 002", on_bad="raise")

    # ... and only warns with on_bad="warn"
    for bad in ("EEG 001", "EEG 002"):
        inst.info["bads"] = [bad]
        with pytest.warns(RuntimeWarning):
            set_bipolar_reference(inst, "EEG 001", "EEG 002", on_bad="warn")


def _check_channel_names(inst, ref_names):
    """Assert that the reference channels exist and the info is consistent.

    Parameters
    ----------
    inst : Raw | Epochs | Evoked
        Instance the reference channels were added to.
    ref_names : str | list of str
        Name(s) of the added reference channel(s).
    """
    names = [ref_names] if isinstance(ref_names, str) else ref_names

    # Every requested name must resolve to a channel index in `ch_names`
    found = pick_channels(inst.info["ch_names"], names)
    assert len(found) == len(names)

    # `ch_names` and the `chs` list must agree; raises on any mismatch
    inst.info._check_consistency()


def _make_eeg_epochs(proj=True):
    """Return preloaded EEG-only epochs around event ID 1 of the sample data.

    The raw file and the events are re-read on every call, so each call yields
    an instance independent of any earlier one.

    Parameters
    ----------
    proj : bool | str
        Passed to ``Epochs``. With the default (True), adding a reference
        channel afterwards is prohibited; ``"delayed"`` allows removal of the
        CAR projector when re-referencing.
    """
    raw = read_raw_fif(fif_fname, preload=True)
    events = read_events(eve_fname)
    picks_eeg = pick_types(raw.info, meg=False, eeg=True)
    return Epochs(
        raw,
        events=events,
        event_id=1,
        tmin=-0.2,
        tmax=0.5,
        picks=picks_eeg,
        preload=True,
        proj=proj,
    )


@testing.requires_testing_data
def test_add_reference():
    """Test adding zero-filled reference channels to Raw, Epochs and Evoked."""
    raw = read_raw_fif(fif_fname, preload=True)
    picks_eeg = pick_types(raw.info, meg=False, eeg=True)
    # check if channel already exists
    pytest.raises(ValueError, add_reference_channels, raw, raw.info["ch_names"][0])
    # add reference channel to Raw
    raw_ref = add_reference_channels(raw, "Ref", copy=True)
    assert_equal(raw_ref._data.shape[0], raw._data.shape[0] + 1)
    assert_array_equal(raw._data[picks_eeg, :], raw_ref._data[picks_eeg, :])
    _check_channel_names(raw_ref, "Ref")

    orig_nchan = raw.info["nchan"]
    raw = add_reference_channels(raw, "Ref", copy=False)
    assert_array_equal(raw._data, raw_ref._data)
    assert_equal(raw.info["nchan"], orig_nchan + 1)
    _check_channel_names(raw, "Ref")

    # for Neuromag fif's, the reference electrode location is placed in
    # elements [3:6] of each "data" electrode location
    assert_allclose(
        raw.info["chs"][-1]["loc"][:3], raw.info["chs"][picks_eeg[0]]["loc"][3:6], 1e-6
    )

    ref_idx = raw.ch_names.index("Ref")
    ref_data, _ = raw[ref_idx]
    assert_array_equal(ref_data, 0)

    # add reference channel to Raw when no digitization points exist
    raw = read_raw_fif(fif_fname).crop(0, 1).load_data()
    picks_eeg = pick_types(raw.info, meg=False, eeg=True)
    del raw.info["dig"]

    raw_ref = add_reference_channels(raw, "Ref", copy=True)

    assert_equal(raw_ref._data.shape[0], raw._data.shape[0] + 1)
    assert_array_equal(raw._data[picks_eeg, :], raw_ref._data[picks_eeg, :])
    _check_channel_names(raw_ref, "Ref")

    orig_nchan = raw.info["nchan"]
    raw = add_reference_channels(raw, "Ref", copy=False)
    assert_array_equal(raw._data, raw_ref._data)
    assert_equal(raw.info["nchan"], orig_nchan + 1)
    _check_channel_names(raw, "Ref")

    # Test adding an existing channel as reference channel
    pytest.raises(ValueError, add_reference_channels, raw, raw.info["ch_names"][0])

    # add two reference channels to Raw
    raw_ref = add_reference_channels(raw, ["M1", "M2"], copy=True)
    _check_channel_names(raw_ref, ["M1", "M2"])
    assert_equal(raw_ref._data.shape[0], raw._data.shape[0] + 2)
    assert_array_equal(raw._data[picks_eeg, :], raw_ref._data[picks_eeg, :])
    assert_array_equal(raw_ref._data[-2:, :], 0)

    raw = add_reference_channels(raw, ["M1", "M2"], copy=False)
    _check_channel_names(raw, ["M1", "M2"])
    ref_idx = raw.ch_names.index("M1")
    ref_idy = raw.ch_names.index("M2")
    ref_data, _ = raw[[ref_idx, ref_idy]]
    assert_array_equal(ref_data, 0)

    # add reference channel to epochs
    # default: proj=True, after which adding a Ref channel is prohibited
    epochs = _make_eeg_epochs()
    with pytest.raises(RuntimeError):
        add_reference_channels(epochs, "Ref")

    # create epochs in delayed mode, allowing removal of CAR when re-reffing
    epochs = _make_eeg_epochs(proj="delayed")
    epochs_ref = add_reference_channels(epochs, "Ref", copy=True)

    assert_equal(epochs_ref._data.shape[1], epochs._data.shape[1] + 1)
    _check_channel_names(epochs_ref, "Ref")
    ref_idx = epochs_ref.ch_names.index("Ref")
    ref_data = epochs_ref.get_data(picks=[ref_idx])[:, 0]
    assert_array_equal(ref_data, 0)
    picks_eeg = pick_types(epochs.info, meg=False, eeg=True)
    assert_array_equal(epochs.get_data(picks_eeg), epochs_ref.get_data(picks_eeg))

    # add two reference channels to epochs
    epochs = _make_eeg_epochs(proj="delayed")
    with pytest.warns(RuntimeWarning, match="for this channel is unknown or ambiguous"):
        epochs_ref = add_reference_channels(epochs, ["M1", "M2"], copy=True)
    assert_equal(epochs_ref._data.shape[1], epochs._data.shape[1] + 2)
    _check_channel_names(epochs_ref, ["M1", "M2"])
    ref_idx = epochs_ref.ch_names.index("M1")
    ref_idy = epochs_ref.ch_names.index("M2")
    assert_equal(epochs_ref.info["chs"][ref_idx]["ch_name"], "M1")
    assert_equal(epochs_ref.info["chs"][ref_idy]["ch_name"], "M2")
    ref_data = epochs_ref.get_data([ref_idx, ref_idy])
    assert_array_equal(ref_data, 0)
    picks_eeg = pick_types(epochs.info, meg=False, eeg=True)
    assert_array_equal(epochs.get_data(picks_eeg), epochs_ref.get_data(picks_eeg))

    # add reference channel to evoked
    evoked = _make_eeg_epochs(proj="delayed").average()
    evoked_ref = add_reference_channels(evoked, "Ref", copy=True)
    assert_equal(evoked_ref.data.shape[0], evoked.data.shape[0] + 1)
    _check_channel_names(evoked_ref, "Ref")
    ref_idx = evoked_ref.ch_names.index("Ref")
    ref_data = evoked_ref.data[ref_idx, :]
    assert_array_equal(ref_data, 0)
    picks_eeg = pick_types(evoked.info, meg=False, eeg=True)
    assert_array_equal(evoked.data[picks_eeg, :], evoked_ref.data[picks_eeg, :])

    # add two reference channels to evoked
    evoked = _make_eeg_epochs(proj="delayed").average()
    with pytest.warns(RuntimeWarning, match="for this channel is unknown or ambiguous"):
        evoked_ref = add_reference_channels(evoked, ["M1", "M2"], copy=True)
    assert_equal(evoked_ref.data.shape[0], evoked.data.shape[0] + 2)
    _check_channel_names(evoked_ref, ["M1", "M2"])
    ref_idx = evoked_ref.ch_names.index("M1")
    ref_idy = evoked_ref.ch_names.index("M2")
    ref_data = evoked_ref.data[[ref_idx, ref_idy], :]
    assert_array_equal(ref_data, 0)
    picks_eeg = pick_types(evoked.info, meg=False, eeg=True)
    assert_array_equal(evoked.data[picks_eeg, :], evoked_ref.data[picks_eeg, :])

    # Test invalid inputs
    raw = read_raw_fif(fif_fname, preload=False)
    with pytest.raises(RuntimeError, match="loaded"):
        add_reference_channels(raw, ["Ref"])
    raw.load_data()
    with pytest.raises(ValueError, match="Channel.*already.*"):
        add_reference_channels(raw, raw.ch_names[:1])
    with pytest.raises(TypeError, match="instance of"):
        add_reference_channels(raw, 1)

    # gh-10878: Raw, Epochs and Evoked must agree after adding a reference
    raw = read_raw_fif(raw_fname).crop(0, 1, include_tmax=False).load_data()
    data = raw.copy().add_reference_channels(["REF"]).pick(picks="eeg")
    data = data.get_data()
    epochs = make_fixed_length_epochs(raw).load_data()
    data_2 = epochs.copy().add_reference_channels(["REF"]).pick(picks="eeg")
    data_2 = data_2.get_data(copy=False)[0]
    assert_allclose(data, data_2)
    evoked = epochs.average()
    data_3 = evoked.copy().add_reference_channels(["REF"]).pick(picks="eeg")
    data_3 = data_3.get_data()
    assert_allclose(data, data_3)


@pytest.mark.parametrize("n_ref", (1, 2))
def test_add_reorder(n_ref):
    """Test that a reference channel can be added and then data reordered.

    Regression test for gh-8300: after appending ``n_ref`` reference channels
    with :func:`mne.add_reference_channels`, reordering the channels must
    permute the data rows consistently with the channel names.
    """
    # gh-8300
    raw = read_raw_fif(raw_fname).crop(0, 0.1).del_proj().pick("eeg")
    assert len(raw.ch_names) == 60
    # New names are chosen so that moving the last one ("EEG 000") to the
    # front yields the contiguous sequence "EEG 000", "EEG 001", ...
    chs = [f"EEG {60 + ii:03}" for ii in range(1, n_ref)] + ["EEG 000"]
    # Adding channels to non-preloaded data must fail.
    with pytest.raises(RuntimeError, match="preload"):
        with _record_warnings():  # ignore multiple warning
            add_reference_channels(raw, chs, copy=False)
    raw.load_data()
    if n_ref == 1:
        ctx = nullcontext()
    else:
        assert n_ref == 2
        # With two new channels a location warning is expected.
        ctx = pytest.warns(RuntimeWarning, match="this channel is unknown or ambiguous")
    with ctx:
        add_reference_channels(raw, chs, copy=False)
    data = raw.get_data()
    # Every appended reference channel (not just the last one) must be zero.
    assert_array_equal(data[-n_ref:], 0.0)
    assert raw.ch_names[-n_ref:] == chs
    # Rotate the last channel ("EEG 000") to the front.
    raw.reorder_channels(raw.ch_names[-1:] + raw.ch_names[:-1])
    assert raw.ch_names == [f"EEG {ii:03}" for ii in range(60 + n_ref)]
    # Undoing the rotation on the data must recover the pre-reorder array.
    data_new = raw.get_data()
    data_new = np.concatenate([data_new[1:], data_new[:1]])
    assert_allclose(data, data_new)


def test_bipolar_combinations():
    """Test bipolar channel generation.

    The test checks that :func:`mne.set_bipolar_reference`:

    - produces ``anode - cathode`` for a single pair;
    - produces ``anode - cathode`` for every pairwise combination of channels;
    - produces ``anode - cathode`` for pairs that use a channel as both anode
      and cathode;
    - keeps or drops the source channels according to ``drop_refs``;
    - marks a bipolar channel as bad when either source channel is bad.
    """
    ch_names = [f"CH{ni + 1}" for ni in range(10)]
    info = create_info(
        ch_names=ch_names, sfreq=1000.0, ch_types=["eeg"] * len(ch_names)
    )
    # Seeded generator so that the test data (and any failure) is reproducible.
    rng = np.random.default_rng(0)
    raw_data = rng.standard_normal((len(ch_names), 1000))
    raw = RawArray(raw_data, info)

    def _check_bipolar(raw_test, ch_a, ch_b):
        """Assert that channel "ch_a-ch_b" equals raw_data[ch_a] - raw_data[ch_b]."""
        picks = [raw_test.ch_names.index(f"{ch_a}-{ch_b}")]
        get_data_res = raw_test.get_data(picks=picks)[0, :]
        manual_a = raw_data[ch_names.index(ch_a), :]
        manual_b = raw_data[ch_names.index(ch_b), :]
        assert_array_equal(get_data_res, manual_a - manual_b)

    # test classic EOG/ECG bipolar reference (only two channels per pair).
    raw_test = set_bipolar_reference(raw, ["CH2"], ["CH1"], copy=True)
    _check_bipolar(raw_test, "CH2", "CH1")

    # test all combinations.
    a_channels, b_channels = zip(*itertools.combinations(ch_names, 2))
    a_channels, b_channels = list(a_channels), list(b_channels)
    raw_test = set_bipolar_reference(raw, a_channels, b_channels, copy=True)
    for ch_a, ch_b in zip(a_channels, b_channels):
        _check_bipolar(raw_test, ch_a, ch_b)
    # check if reference channels have been dropped (default behavior).
    assert len(raw_test.ch_names) == len(a_channels)

    raw_test = set_bipolar_reference(
        raw, a_channels, b_channels, drop_refs=False, copy=True
    )
    # check if reference channels have been kept correctly.
    assert len(raw_test.ch_names) == len(a_channels) + len(ch_names)
    for idx, ch_label in enumerate(ch_names):
        manual_ch = raw_data[np.newaxis, idx]
        assert_array_equal(raw_test.get_data(ch_label), manual_ch)

    # test bipolars with a channel in both lists (anode & cathode).
    raw_test = set_bipolar_reference(raw, ["CH2", "CH1"], ["CH1", "CH2"], copy=True)
    _check_bipolar(raw_test, "CH2", "CH1")
    _check_bipolar(raw_test, "CH1", "CH2")

    # test if bipolar channel is bad if anode is a bad channel
    raw.info["bads"] = ["CH1"]
    raw_test = set_bipolar_reference(
        raw, ["CH1"], ["CH2"], on_bad="ignore", ch_name="bad_bipolar", copy=True
    )
    assert raw_test.info["bads"] == ["bad_bipolar"]

    # test if bipolar channel is bad if cathode is a bad channel
    raw.info["bads"] = ["CH2"]
    raw_test = set_bipolar_reference(
        raw, ["CH1"], ["CH2"], on_bad="ignore", ch_name="bad_bipolar", copy=True
    )
    assert raw_test.info["bads"] == ["bad_bipolar"]
