import concurrent.futures as cf
import functools
import itertools
import os
import pathlib
import platform
import sys
import time
import unittest
import warnings

import mpi4py.util.pool as pool
from mpi4py import MPI

try:
    import mpitestutil as testutil
except ImportError:
    sys.path.append(os.fspath(pathlib.Path(__file__).resolve().parent))
    import mpitestutil as testutil


# Vendor name and version of the MPI implementation, as returned by
# ``MPI.get_vendor()``; used below to skip ``test_apply`` for one specific
# vendor/version/machine combination.
mpi_vendor_name, mpi_vendor_version = MPI.get_vendor()


def sqr(x, wait=0.0):
    """Return ``x * x`` after sleeping for *wait* seconds.

    The sleep lets tests simulate a task that takes a known amount of time.
    """
    time.sleep(wait)
    squared = x * x
    return squared


def mul(x, y):  # noqa: FURB118
    """Return the product ``x * y``."""
    product = x * y
    return product


def identity(x):
    """Return *x* unchanged."""
    unchanged = x
    return unchanged


def raising():
    """Unconditionally raise ``KeyError("key")``."""
    error = KeyError("key")
    raise error


# Durations in seconds. The tests below pass them as sleep times to pool
# tasks and derive their wait()/get() timeouts from them.
TIMEOUT1 = 0.1
TIMEOUT2 = 0.2


class TimingWrapper:
    """Callable proxy that records how long the wrapped call took.

    After each call, ``elapsed`` holds the duration of that call in
    seconds as measured with ``time.monotonic()``; it is ``None`` until
    the first call.
    """

    def __init__(self, func):
        self.elapsed = None
        self.func = func

    def __call__(self, *args, **kwds):
        """Call the wrapped function and return (or propagate) its outcome."""
        started = time.monotonic()
        try:
            outcome = self.func(*args, **kwds)
        finally:
            # Recorded even when the wrapped call raises.
            self.elapsed = time.monotonic() - started
        return outcome


class BaseTestPool:
    # Shared test cases for pool implementations.
    #
    # Concrete test classes combine this mixin with ``unittest.TestCase``
    # and set ``PoolType``. Most tests run against the single one-worker
    # pool created in ``setUpClass``; tests that need to close or terminate
    # a pool build their own through ``Pool``.
    PoolType = None  # pool class under test; set by subclasses

    @classmethod
    def Pool(cls, *args, **kwargs):
        # Build a pool of the class under test. When the ``coverage``
        # module is loaded, ``python_args`` is added to the keyword
        # arguments -- presumably so that spawned workers run under
        # coverage as well (TODO confirm against the pool implementation).
        if "coverage" in sys.modules:
            kwargs["python_args"] = "-m coverage run".split()
        Pool = cls.PoolType
        return Pool(*args, **kwargs)

    @classmethod
    def setUpClass(cls):
        # One shared pool with a single worker for the whole test class.
        super().setUpClass()
        cls.pool = cls.Pool(1)

    @classmethod
    def tearDownClass(cls):
        # Shut the shared pool down and drop the reference to it.
        cls.pool.terminate()
        cls.pool.join()
        cls.pool = None
        super().tearDownClass()

    @unittest.skipIf(
        mpi_vendor_name == "Open MPI"
        and mpi_vendor_version < (5, 0, 0)
        and platform.machine() == "s390x",
        "openmpi(<5.0.0)-s390x"
    )
    def test_apply(self):
        # apply() with positional and with keyword arguments must match a
        # direct call.
        papply = self.pool.apply
        self.assertEqual(papply(sqr, (5,)), sqr(5))
        self.assertEqual(papply(sqr, (), {"x": 3}), sqr(x=3))

    def test_map(self):
        # map() over a range, a generator and a list, first with the
        # default chunksize and then with chunksize=20.
        self.assertEqual(
            self.pool.map(sqr, range(10)), list(map(sqr, list(range(10))))
        )
        self.assertEqual(
            self.pool.map(sqr, (i for i in range(10))),
            list(map(sqr, list(range(10)))),
        )
        self.assertEqual(
            self.pool.map(sqr, list(range(10))),
            list(map(sqr, list(range(10)))),
        )

        self.assertEqual(
            self.pool.map(sqr, range(100), chunksize=20),
            list(map(sqr, list(range(100)))),
        )
        self.assertEqual(
            self.pool.map(sqr, (i for i in range(100)), chunksize=20),
            list(map(sqr, list(range(100)))),
        )
        self.assertEqual(
            self.pool.map(sqr, list(range(100)), chunksize=20),
            list(map(sqr, list(range(100)))),
        )

    def test_imap(self):
        # imap() must yield results in input order, both when consumed
        # with list() and when stepped with next().
        self.assertEqual(
            list(self.pool.imap(sqr, range(10))),
            list(map(sqr, list(range(10)))),
        )
        self.assertEqual(
            list(self.pool.imap(sqr, (i for i in range(10)))),
            list(map(sqr, list(range(10)))),
        )
        self.assertEqual(
            list(self.pool.imap(sqr, list(range(10)))),
            list(map(sqr, list(range(10)))),
        )

        # Step-by-step consumption; the iterator must be exhausted after
        # the last item.
        it = self.pool.imap(sqr, range(10))
        for i in range(10):
            self.assertEqual(next(it), i * i)
        self.assertRaises(StopIteration, next, it)
        it = self.pool.imap(sqr, list(range(10)))
        for i in range(10):
            self.assertEqual(next(it), i * i)
        self.assertRaises(StopIteration, next, it)

        # Same checks with an explicit chunksize.
        it = self.pool.imap(sqr, range(100), chunksize=20)
        for i in range(100):
            self.assertEqual(next(it), i * i)
        self.assertRaises(StopIteration, next, it)
        it = self.pool.imap(sqr, list(range(100)), chunksize=20)
        for i in range(100):
            self.assertEqual(next(it), i * i)
        self.assertRaises(StopIteration, next, it)

    def test_imap_unordered(self):
        # Results may arrive in any order, so they are sorted before being
        # compared with the expected list.
        args = list(range(10))
        result = list(map(sqr, args))
        it = self.pool.imap_unordered(sqr, args)
        self.assertEqual(sorted(it), result)
        it = self.pool.imap_unordered(sqr, iter(args))
        self.assertEqual(sorted(it), result)
        it = self.pool.imap_unordered(sqr, (a for a in args))
        self.assertEqual(sorted(it), result)

        args = list(range(100))
        result = list(map(sqr, args))
        it = self.pool.imap_unordered(sqr, args, chunksize=20)
        self.assertEqual(sorted(it), result)
        it = self.pool.imap_unordered(sqr, iter(args), chunksize=20)
        self.assertEqual(sorted(it), result)
        it = self.pool.imap_unordered(sqr, (a for a in args), chunksize=20)
        self.assertEqual(sorted(it), result)

    def test_starmap(self):
        # starmap() must agree with itertools.starmap(), with and without
        # an explicit chunksize.
        tuples = list(zip(range(10), range(9, -1, -1)))
        self.assertEqual(
            self.pool.starmap(mul, tuples),
            list(itertools.starmap(mul, tuples)),
        )
        tuples = list(zip(range(100), range(99, -1, -1)))
        self.assertEqual(
            self.pool.starmap(mul, tuples, chunksize=20),
            list(itertools.starmap(mul, tuples)),
        )

    def test_istarmap(self):
        # istarmap() must yield results in input order for a list and for
        # a one-shot zip iterator.
        tuples = list(zip(range(10), range(9, -1, -1)))
        result = list(itertools.starmap(mul, tuples))
        it = self.pool.istarmap(mul, tuples)
        self.assertEqual(list(it), result)
        iterator = zip(range(10), range(9, -1, -1))
        it = self.pool.istarmap(mul, iterator)
        self.assertEqual(list(it), result)

        # Step-by-step consumption, then exhaustion.
        tuples = list(zip(range(10), range(9, -1, -1)))
        it = self.pool.istarmap(mul, tuples)
        for i, j in tuples:
            self.assertEqual(next(it), i * j)
        self.assertRaises(StopIteration, next, it)

        tuples = list(zip(range(100), range(99, -1, -1)))
        it = self.pool.istarmap(mul, tuples, chunksize=20)
        for i, j in tuples:
            self.assertEqual(next(it), i * j)
        self.assertRaises(StopIteration, next, it)

    def test_istarmap_unordered(self):
        # Order is unspecified, so both sides are sorted before comparing.
        tuples = list(zip(range(10), range(9, -1, -1)))
        result = list(itertools.starmap(mul, tuples))
        it = self.pool.istarmap_unordered(mul, tuples)
        self.assertEqual(sorted(it), sorted(result))
        iterator = zip(range(10), range(9, -1, -1))
        it = self.pool.istarmap_unordered(mul, iterator)
        self.assertEqual(sorted(it), sorted(result))

        tuples = list(zip(range(100), range(99, -1, -1)))
        result = list(itertools.starmap(mul, tuples))
        it = self.pool.istarmap_unordered(mul, tuples, chunksize=20)
        self.assertEqual(sorted(it), sorted(result))

    def test_apply_async(self):
        # Immediate task: get() returns the value.
        res = self.pool.apply_async(sqr, (7,))
        self.assertEqual(res.get(), 49)

        # Task sleeping TIMEOUT2 seconds: the time spent in get() must be
        # within a factor of ten of that duration in either direction.
        res = self.pool.apply_async(sqr, (7, TIMEOUT2))
        get = TimingWrapper(res.get)
        self.assertEqual(get(), 49)
        self.assertLess(get.elapsed, TIMEOUT2 * 10)
        self.assertGreater(get.elapsed, TIMEOUT2 / 10)

    def test_apply_async_timeout(self):
        # While the task is still sleeping, the result is not ready,
        # successful() raises ValueError and a short get() times out.
        res = self.pool.apply_async(sqr, (7, TIMEOUT2))
        self.assertFalse(res.ready())
        self.assertRaises(ValueError, res.successful)
        res.wait(TIMEOUT2 / 100)
        self.assertFalse(res.ready())
        self.assertRaises(ValueError, res.successful)
        self.assertRaises(TimeoutError, res.get, TIMEOUT2 / 100)
        # After an unbounded wait() the result is ready and successful.
        res.wait()
        self.assertTrue(res.ready())
        self.assertTrue(res.successful())
        self.assertEqual(res.get(), 49)

    def test_map_async(self):
        # map_async().get() must equal the builtin map() result.
        args = list(range(10))
        self.assertEqual(
            self.pool.map_async(sqr, args).get(), list(map(sqr, args))
        )
        args = list(range(100))
        self.assertEqual(
            self.pool.map_async(sqr, args, chunksize=20).get(),
            list(map(sqr, args)),
        )

    def test_map_async_callbacks(self):
        # Both callbacks append to the same list, so its length shows that
        # exactly one of them fired per map_async() call.
        call_args = []
        result = self.pool.map_async(
            int,
            ["1", "2"],
            callback=call_args.append,
            error_callback=call_args.append,
        )
        result.wait()
        self.assertTrue(result.successful())
        self.assertEqual(len(call_args), 1)
        self.assertEqual(call_args[0], [1, 2])
        # int("a") fails: the error callback receives the ValueError.
        result = self.pool.map_async(
            int,
            ["a"],
            callback=call_args.append,
            error_callback=call_args.append,
        )
        result.wait()
        self.assertFalse(result.successful())
        self.assertEqual(len(call_args), 2)
        self.assertIsInstance(call_args[1], ValueError)

    def test_starmap_async(self):
        # starmap_async().get() must agree with itertools.starmap().
        tuples = list(zip(range(10), range(9, -1, -1)))
        self.assertEqual(
            self.pool.starmap_async(mul, tuples).get(),
            list(itertools.starmap(mul, tuples)),
        )
        tuples = list(zip(range(1000), range(999, -1, -1)))
        self.assertEqual(
            self.pool.starmap_async(mul, tuples, chunksize=100).get(),
            list(itertools.starmap(mul, tuples)),
        )

    # ---

    def test_terminate(self):
        # Queue many sleeping tasks on a fresh one-worker pool, terminate
        # it right away, and check that the last submitted result finishes
        # without success and that get() raises.
        p = self.Pool(1)
        for _ in range(100):
            p.apply_async(time.sleep, (TIMEOUT1,))
        result = p.apply_async(time.sleep, (TIMEOUT1,))
        p.terminate()
        p.join()
        result.wait(TIMEOUT1)
        result.wait()
        self.assertFalse(result.successful())
        self.assertRaises(Exception, result.get, TIMEOUT1)
        self.assertRaises(Exception, result.get)

        # Same scenario with a single map_async() submission.
        p = self.Pool(1)
        args = [TIMEOUT1] * 100
        result = p.map_async(time.sleep, args, chunksize=1)
        p.terminate()
        p.join()
        result.wait(TIMEOUT1)
        result.wait()
        self.assertFalse(result.successful())
        self.assertRaises(Exception, result.get, TIMEOUT1)
        self.assertRaises(Exception, result.get)

    def test_empty_iterable(self):
        # Every mapping flavour must return an empty result for an empty
        # input.
        p = self.Pool(1)

        self.assertEqual(p.map(sqr, []), [])
        self.assertEqual(list(p.imap(sqr, [])), [])
        self.assertEqual(list(p.imap_unordered(sqr, [])), [])

        self.assertEqual(p.starmap(sqr, []), [])
        self.assertEqual(list(p.istarmap(sqr, [])), [])
        self.assertEqual(list(p.istarmap_unordered(sqr, [])), [])

        self.assertEqual(p.map_async(sqr, []).get(), [])
        self.assertEqual(p.starmap_async(mul, []).get(), [])

        p.close()
        p.join()

    def test_enter_exit(self):
        # A pool can be used as a context manager and joined afterwards.
        pool = self.Pool(1)

        with pool:
            pass
        # NOTE(review): the re-entry check below is disabled; the reason is
        # not visible in this file.
        # with self.assertRaises(ValueError):
        #     with pool:
        #         pass

        pool.join()

    # ---

    def test_async_error_callback(self):
        # The error callback must receive the exception raised by the task,
        # and get() must re-raise it.
        p = self.Pool(1)
        scratchpad = [None]

        def errback(exc):
            scratchpad[0] = exc

        res = p.apply_async(raising, error_callback=errback)
        p.close()
        p.join()
        self.assertRaises(KeyError, res.get)
        self.assertTrue(scratchpad[0])
        self.assertIsInstance(scratchpad[0], KeyError)

    def test_pool_worker_lifetime_early_close(self):
        # Tasks submitted before close() must still complete and deliver
        # their results after join().
        p = self.Pool(1)
        results = []
        for i in range(5):
            results.append(p.apply_async(sqr, (i, TIMEOUT1)))
        p.close()
        p.join()
        for j, res in enumerate(results):
            self.assertEqual(res.get(), sqr(j))

    # ---

    def test_arg_processes(self):
        # Non-positive process counts are rejected.
        with self.assertRaises(ValueError):
            self.Pool(-1)
        with self.assertRaises(ValueError):
            self.Pool(0)

    def test_arg_initializer(self):
        # A callable initializer with matching initargs is accepted. The
        # pool is closed and joined explicitly, like every other pool
        # created by these tests, so it is not left behind.
        p = self.Pool(1, initializer=identity, initargs=(123,))
        p.close()
        p.join()
        # A non-callable initializer is rejected.
        with self.assertRaises(TypeError):
            self.Pool(initializer=123)

    def test_unsupported_args(self):
        # With warnings turned into errors, passing ``maxtasksperchild`` or
        # ``context`` must surface a UserWarning.
        with warnings.catch_warnings():
            warnings.simplefilter("error")
            with self.assertRaises(UserWarning):
                with self.Pool(1, maxtasksperchild=1):
                    pass
            with self.assertRaises(UserWarning):
                with self.Pool(1, context=object):
                    pass


# ---


@unittest.skipIf(testutil.disable_mpi_spawn(), "mpi-spawn")
@unittest.skipIf(MPI.COMM_WORLD.Get_size() > 1, "mpi-world-size>1")
class TestProcessPool(BaseTestPool, unittest.TestCase):
    """Shared pool tests bound to ``pool.Pool``.

    Skipped when ``testutil.disable_mpi_spawn()`` is true or when
    ``MPI.COMM_WORLD`` holds more than one process.
    """

    PoolType = pool.Pool


class TestThreadPool(BaseTestPool, unittest.TestCase):
    """Shared pool tests bound to ``pool.ThreadPool``."""

    PoolType = pool.ThreadPool


# ---


class ExtraExecutorMixing:
    """Executor mixin whose ``map`` accepts ``unordered`` and adds ``starmap``.

    Intended to precede an executor class in the MRO: both methods
    delegate to the ``map`` of the next class in line.
    """

    @staticmethod
    def _apply_args(fn, args):
        """Call *fn* with the items of *args* as positional arguments."""
        return fn(*args)

    def map(self, fn, iterable, timeout=None, chunksize=1, unordered=False):
        """Forward to the parent ``map``; *unordered* is accepted, ignored."""
        del unordered  # ignored, unused
        forwarded = {"timeout": timeout, "chunksize": chunksize}
        return super().map(fn, iterable, **forwarded)

    def starmap(
        self, fn, iterable, timeout=None, chunksize=1, unordered=False
    ):
        """Like ``map``, but unpack each item of *iterable* as arguments."""
        del unordered  # ignored, unused
        unpacking_fn = functools.partial(self._apply_args, fn)
        forwarded = {"timeout": timeout, "chunksize": chunksize}
        return super().map(unpacking_fn, iterable, **forwarded)


class ExtraExecutor(ExtraExecutorMixing, cf.ThreadPoolExecutor):
    """``ThreadPoolExecutor`` extended with the ``ExtraExecutorMixing`` API."""


class ExtraPool(pool.Pool):
    """``pool.Pool`` variant that uses ``ExtraExecutor`` as its executor."""

    Executor = ExtraExecutor


class TestExtraPool(BaseTestPool, unittest.TestCase):
    """Shared pool tests bound to ``ExtraPool``."""

    PoolType = ExtraPool

    @classmethod
    def Pool(cls, *args, **kwargs):
        """Instantiate ``PoolType`` with the arguments exactly as given.

        Unlike the inherited helper, no ``python_args`` entry is added
        when ``coverage`` is loaded.
        """
        pool_factory = cls.PoolType
        return pool_factory(*args, **kwargs)


# Remove the module-level name so that TestExtraPool is not found when the
# test cases of this module are collected.
# NOTE(review): presumably a deliberate way to disable these tests -- confirm.
del TestExtraPool


# ---

# Run this module's tests when it is executed as a script.
if __name__ == "__main__":
    unittest.main()
