# Authors: The MNE-Python contributors.
# License: BSD-3-Clause
# Copyright the MNE-Python contributors.

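# NOTE: Tests in this directory must be self-contained (everything a test uses,
# including imports and helpers, is defined inside the test function itself)
# because they are executed in a separate IPython kernel.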

import pytest

from mne.datasets import testing


@testing.requires_testing_data
def test_notebook_alignment(renderer_notebook, brain_gc, nbexec):
    """Check that plot_alignment produces a displayed figure in a notebook.

    All names are imported locally because this test is executed in a
    separate IPython kernel.
    """
    import pytest

    import mne

    # Resolve the testing-data location with _MNE_FAKE_HOME_DIR unset.
    with pytest.MonkeyPatch().context() as monkeypatch:
        monkeypatch.delenv("_MNE_FAKE_HOME_DIR")
        testing_root = mne.datasets.testing.data_path(download=False)
    meg_dir = testing_root / "MEG" / "sample"
    measurement_info = mne.io.read_info(meg_dir / "sample_audvis_trunc_raw.fif")
    mne.viz.set_3d_backend("notebook")
    alignment_fig = mne.viz.plot_alignment(
        measurement_info,
        meg_dir / "sample_audvis_trunc-trans.fif",
        subject="sample",
        dig=True,
        meg=["helmet", "sensors"],
        subjects_dir=testing_root / "subjects",
        surfaces=["head-dense"],
    )
    assert alignment_fig.display is not None


@pytest.mark.slowtest  # ~3 min on GitHub macOS
@testing.requires_testing_data
def test_notebook_interactive(renderer_notebook, brain_gc, nbexec):
    """Test interactive modes.

    Plot a source estimate with the time viewer in the notebook backend.
    Click every ``Button`` action, which must write the requested movie and
    screenshot files. Then compare screenshot shapes with and without the
    time viewer. Everything is imported locally because this test is
    executed in a separate IPython kernel.
    """
    import shutil
    import tempfile
    import time
    from contextlib import contextmanager
    from pathlib import Path

    import matplotlib.pyplot as plt
    import pytest
    from ipywidgets import Button
    from numpy.testing import assert_allclose

    import mne
    from mne.datasets import testing

    # Resolve the testing-data location with _MNE_FAKE_HOME_DIR unset.
    with pytest.MonkeyPatch().context() as mp:
        mp.delenv("_MNE_FAKE_HOME_DIR")
        data_path = testing.data_path(download=False)
    sample_dir = data_path / "MEG" / "sample"
    subjects_dir = data_path / "subjects"
    fname_stc = sample_dir / "sample_audvis_trunc-meg"
    stc = mne.read_source_estimate(fname_stc, subject="sample")
    stc.crop(0.1, 0.11)
    initial_time = 0.13
    mne.viz.set_3d_backend("notebook")
    brain_class = mne.viz.get_brain_class()

    @contextmanager
    def interactive(on):
        """Temporarily set matplotlib's interactive mode to ``on``."""
        old = plt.isinteractive()
        plt.interactive(on)
        try:
            yield
        finally:
            plt.interactive(old)

    # mkdtemp() never deletes the directory itself, so remove it explicitly
    # (ignore_errors: written files may still be locked, e.g. on Windows).
    tmp_path = Path(tempfile.mkdtemp())
    try:
        with interactive(False):
            brain = stc.plot(
                subjects_dir=subjects_dir,
                initial_time=initial_time,
                clim=dict(kind="value", pos_lims=[3, 6, 9]),
                time_viewer=True,
                show_traces=True,
                hemi="lh",
                size=300,
            )
            # Close the Brain even when one of the assertions below fails.
            try:
                assert isinstance(brain, brain_class)
                assert brain._renderer.figure.notebook
                assert brain._renderer.figure.display is not None
                brain._renderer._update()
                movie_path = tmp_path / "test.gif"
                screenshot_path = tmp_path / "test.png"
                actions = brain._renderer.actions
                # Setting the output paths writes nothing; the files only
                # appear once the corresponding buttons are clicked below.
                assert actions["movie_field"]._action.value == ""
                actions["movie_field"]._action.value = str(movie_path)
                assert not movie_path.is_file()
                assert actions["screenshot_field"]._action.value == ""
                actions["screenshot_field"]._action.value = str(screenshot_path)
                assert not screenshot_path.is_file()
                total_number_of_buttons = sum("_field" not in k for k in actions)
                assert "play" in actions
                # "play" is counted in the total above but is not a Button
                # widget (it has no click() method), so start the count at 1.
                number_of_buttons = 1
                button_names = list()
                for name, action in actions.items():
                    widget = action._action
                    if isinstance(widget, Button):
                        widget.click()
                        button_names.append(name)
                        number_of_buttons += 1
                assert number_of_buttons == total_number_of_buttons
                time.sleep(0.5)
                assert "movie" in button_names, button_names
                assert movie_path.is_file(), movie_path
                assert "screenshot" in button_names, button_names
                assert screenshot_path.is_file(), screenshot_path
                img_nv = brain.screenshot()
                assert img_nv.shape == (300, 300, 3), img_nv.shape
                img_v = brain.screenshot(time_viewer=True)
                assert img_v.shape[1:] == (300, 3), img_v.shape
                # The time-viewer screenshot is expected to be ~1.25x taller
                # (presumably the traces panel is appended below the brain).
                # XXX This rtol is not very good, ideally would be zero
                # err_msg must be a str: numpy appends text to it on failure.
                assert_allclose(
                    img_v.shape[0],
                    img_nv.shape[0] * 1.25,
                    err_msg=f"img_v.shape={img_v.shape}, img_nv.shape={img_nv.shape}",
                    rtol=0.1,
                )
            finally:
                brain.close()
    finally:
        shutil.rmtree(tmp_path, ignore_errors=True)


@testing.requires_testing_data
def test_notebook_button_counts(renderer_notebook, brain_gc, nbexec):
    """Check that every non-field action of a notebook figure is clickable.

    Imports are local because this test is executed in a separate IPython
    kernel.
    """
    from ipywidgets import Button

    import mne

    mne.viz.set_3d_backend("notebook")
    renderer = mne.viz.create_3d_figure(size=(100, 100), scene=False)
    figure = renderer.scene()
    mne.viz.set_3d_title(figure, "Notebook testing")
    mne.viz.set_3d_view(figure, 200, 70, focalpoint=[0, 0, 0])
    # The figure has no display until show() is called.
    assert figure.display is None
    renderer.show()
    expected = len([name for name in renderer.actions if "_field" not in name])
    clicked = 0
    for entry in renderer.actions.values():
        button = entry._action
        if not isinstance(button, Button):
            continue
        button.click()
        clicked += 1
    assert clicked == expected
    assert figure.display is not None
