# Authors: The MNE-Python contributors.
# License: BSD-3-Clause
# Copyright the MNE-Python contributors.

from pathlib import Path

import numpy as np
import pytest
from numpy.testing import assert_array_equal

from mne.parallel import parallel_func
from mne.utils import ProgressBar, array_split_idx, catch_logging, use_log_level


def test_progressbar(monkeypatch):
    """Test ProgressBar construction, iteration and the MNE_TQDM backends."""
    data = np.arange(10)
    bar = ProgressBar(data)
    # An iterable input is kept as-is and its length becomes max_value
    assert bar.iterable is data
    assert bar.max_value == 10

    # An integer input only sets max_value; there is nothing to iterate
    bar = ProgressBar(10)
    assert bar.max_value == 10
    assert bar.iterable is None
    with pytest.raises(TypeError, match="not iterable"):
        for _ in bar:
            pass

    # A valid tqdm backend reports which one it uses at debug level
    monkeypatch.setenv("MNE_TQDM", "tqdm")
    with catch_logging("debug") as captured, ProgressBar(np.arange(3)) as bar:
        for _ in bar:
            pass
    assert "Using ProgressBar with tqdm\n" in captured.getvalue()

    # Malformed backend specifications are rejected at construction time
    bad_settings = (
        ("broken", "Invalid value for the"),
        ("tqdm.broken", "Unknown tqdm"),
    )
    for setting, message in bad_settings:
        monkeypatch.setenv("MNE_TQDM", setting)
        with pytest.raises(ValueError, match=message):
            ProgressBar(np.arange(3))

    # With "off", the backend line must be the only debug output
    monkeypatch.setenv("MNE_TQDM", "off")
    with catch_logging("debug") as captured, ProgressBar(np.arange(3)) as bar:
        for _ in bar:
            pass
    assert captured.getvalue() == "Using ProgressBar with off\n"


def _identity(x):
    return x


def test_progressbar_parallel_basic(capsys):
    """Check that parallel_func(total=...) drives its progress bar to 100%."""
    assert capsys.readouterr().out == ""
    n_items = 10
    parallel, run, _ = parallel_func(_identity, total=n_items, n_jobs=1, verbose=True)
    with use_log_level(True):
        result = parallel(run(item) for item in range(n_items))
    assert result == list(range(n_items))
    # Progress output is checked on the captured stderr stream
    stderr_text = capsys.readouterr().err
    assert "100%" in stderr_text


def _identity_block(x, pb):
    for ii in range(len(x)):
        pb.update(ii + 1)
    return x


def test_progressbar_parallel_advanced(capsys):
    """Test ProgressBar with parallel computing, advanced version.

    Each parallel job receives a ``pb.subset`` of one shared ProgressBar and
    reports one update per element. The memmap file behind the bar must mark
    every element as done while the bar is open, and it must be removed once
    the context manager exits.
    """
    assert capsys.readouterr().out == ""
    # This must be "1" because "capsys" won't capture the output properly
    # otherwise
    parallel, p_fun, _ = parallel_func(_identity_block, n_jobs=1, verbose=False)
    arr = np.arange(10)
    n_steps = len(arr)  # one progress step per element
    with use_log_level(True):
        with ProgressBar(n_steps) as pb:
            out = parallel(
                p_fun(x, pb.subset(pb_idx)) for pb_idx, x in array_split_idx(arr, 2)
            )
            assert Path(pb._mmap_fname).is_file()
            # Derive the memmap shape from the bar size rather than hard-coding it
            sum_ = np.memmap(
                pb._mmap_fname, dtype="bool", mode="r", shape=n_steps
            ).sum()
            assert sum_ == n_steps
    assert not Path(pb._mmap_fname).is_file(), "__exit__ not called?"
    # The workers must hand back the data untouched
    assert_array_equal(np.concatenate(out), arr)
    err = capsys.readouterr().err
    assert "100%" in err


def _identity_block_wide(x, pb):
    for ii in range(len(x)):
        for jj in range(2):
            pb.update(ii * 2 + jj + 1)
    return x, pb.idx


def test_progressbar_parallel_more(capsys):
    """Test ProgressBar with parallel computing, two steps per element.

    The bar has ``2 * len(arr)`` steps. ``array_split_idx(..., n_per_split=2)``
    hands each job the matching subset of step indices, and each job reports
    two updates per element. The test checks the step indices, the returned
    data, the memmap contents and cleanup, and the final 100% output.
    """
    assert capsys.readouterr().out == ""
    # This must be "1" because "capsys" won't capture the output properly
    # otherwise
    parallel, p_fun, _ = parallel_func(_identity_block_wide, n_jobs=1, verbose=False)
    arr = np.arange(10)
    n_steps = len(arr) * 2
    with use_log_level(True):
        with ProgressBar(n_steps) as pb:
            results = parallel(
                p_fun(x, pb.subset(pb_idx))
                for pb_idx, x in array_split_idx(arr, 2, n_per_split=2)
            )
            # Together, the subsets must cover every step exactly once, in order
            idxs = np.concatenate([r[1] for r in results])
            assert_array_equal(idxs, np.arange(n_steps))
            data = np.concatenate([r[0] for r in results])
            assert Path(pb._mmap_fname).is_file()
            sum_ = np.memmap(
                pb._mmap_fname, dtype="bool", mode="r", shape=n_steps
            ).sum()
            assert sum_ == n_steps
    assert not Path(pb._mmap_fname).is_file(), "__exit__ not called?"
    # Previously computed but never checked: the workers must return the data
    assert_array_equal(data, arr)
    err = capsys.readouterr().err
    assert "100%" in err
