import concurrent.futures as cf
import functools
import itertools
import os
import pathlib
import sys
import time
import unittest
import warnings

import mpi4py.util.pool as pool
from mpi4py import MPI

try:
    import mpitestutil as testutil
except ImportError:
    sys.path.append(os.fspath(pathlib.Path(__file__).resolve().parent))
    import mpitestutil as testutil

import platform

def sqr(x, wait=0.0):
    time.sleep(wait)
    return x * x


def mul(x, y):  # noqa: FURB118
    return x * y


def identity(x):
    return x


def raising():
    raise KeyError("key")


TIMEOUT1 = 0.1
TIMEOUT2 = 0.2


class TimingWrapper:
    #
    def __init__(self, func):
        self.func = func
        self.elapsed = None

    def __call__(self, *args, **kwds):
        t = time.monotonic()
        try:
            return self.func(*args, **kwds)
        finally:
            self.elapsed = time.monotonic() - t


class BaseTestPool:
    #
    PoolType = None

    @classmethod
    def Pool(cls, *args, **kwargs):
        if "coverage" in sys.modules:
            kwargs["python_args"] = "-m coverage run".split()
        Pool = cls.PoolType
        return Pool(*args, **kwargs)

    @classmethod
    def setUpClass(cls):
        super().setUpClass()
        cls.pool = cls.Pool(1)

    @classmethod
    def tearDownClass(cls):
        cls.pool.terminate()
        cls.pool.join()
        cls.pool = None
        super().tearDownClass()

    @unittest.skipIf('s390x' in platform.machine(), "test_apply fails on s390x (with OpenMPI 4)")
    def test_apply(self):
        papply = self.pool.apply
        self.assertEqual(papply(sqr, (5,)), sqr(5))
        self.assertEqual(papply(sqr, (), {"x": 3}), sqr(x=3))

    def test_map(self):
        self.assertEqual(
            self.pool.map(sqr, range(10)), list(map(sqr, list(range(10))))
        )
        self.assertEqual(
            self.pool.map(sqr, (i for i in range(10))),
            list(map(sqr, list(range(10)))),
        )
        self.assertEqual(
            self.pool.map(sqr, list(range(10))),
            list(map(sqr, list(range(10)))),
        )

        self.assertEqual(
            self.pool.map(sqr, range(100), chunksize=20),
            list(map(sqr, list(range(100)))),
        )
        self.assertEqual(
            self.pool.map(sqr, (i for i in range(100)), chunksize=20),
            list(map(sqr, list(range(100)))),
        )
        self.assertEqual(
            self.pool.map(sqr, list(range(100)), chunksize=20),
            list(map(sqr, list(range(100)))),
        )

    def test_imap(self):
        self.assertEqual(
            list(self.pool.imap(sqr, range(10))),
            list(map(sqr, list(range(10)))),
        )
        self.assertEqual(
            list(self.pool.imap(sqr, (i for i in range(10)))),
            list(map(sqr, list(range(10)))),
        )
        self.assertEqual(
            list(self.pool.imap(sqr, list(range(10)))),
            list(map(sqr, list(range(10)))),
        )

        it = self.pool.imap(sqr, range(10))
        for i in range(10):
            self.assertEqual(next(it), i * i)
        self.assertRaises(StopIteration, next, it)
        it = self.pool.imap(sqr, list(range(10)))
        for i in range(10):
            self.assertEqual(next(it), i * i)
        self.assertRaises(StopIteration, next, it)

        it = self.pool.imap(sqr, range(100), chunksize=20)
        for i in range(100):
            self.assertEqual(next(it), i * i)
        self.assertRaises(StopIteration, next, it)
        it = self.pool.imap(sqr, list(range(100)), chunksize=20)
        for i in range(100):
            self.assertEqual(next(it), i * i)
        self.assertRaises(StopIteration, next, it)

    def test_imap_unordered(self):
        args = list(range(10))
        result = list(map(sqr, args))
        it = self.pool.imap_unordered(sqr, args)
        self.assertEqual(sorted(it), result)
        it = self.pool.imap_unordered(sqr, iter(args))
        self.assertEqual(sorted(it), result)
        it = self.pool.imap_unordered(sqr, (a for a in args))
        self.assertEqual(sorted(it), result)

        args = list(range(100))
        result = list(map(sqr, args))
        it = self.pool.imap_unordered(sqr, args, chunksize=20)
        self.assertEqual(sorted(it), result)
        it = self.pool.imap_unordered(sqr, iter(args), chunksize=20)
        self.assertEqual(sorted(it), result)
        it = self.pool.imap_unordered(sqr, (a for a in args), chunksize=20)
        self.assertEqual(sorted(it), result)

    def test_starmap(self):
        tuples = list(zip(range(10), range(9, -1, -1)))
        self.assertEqual(
            self.pool.starmap(mul, tuples),
            list(itertools.starmap(mul, tuples)),
        )
        tuples = list(zip(range(100), range(99, -1, -1)))
        self.assertEqual(
            self.pool.starmap(mul, tuples, chunksize=20),
            list(itertools.starmap(mul, tuples)),
        )

    def test_istarmap(self):
        tuples = list(zip(range(10), range(9, -1, -1)))
        result = list(itertools.starmap(mul, tuples))
        it = self.pool.istarmap(mul, tuples)
        self.assertEqual(list(it), result)
        iterator = zip(range(10), range(9, -1, -1))
        it = self.pool.istarmap(mul, iterator)
        self.assertEqual(list(it), result)

        tuples = list(zip(range(10), range(9, -1, -1)))
        it = self.pool.istarmap(mul, tuples)
        for i, j in tuples:
            self.assertEqual(next(it), i * j)
        self.assertRaises(StopIteration, next, it)

        tuples = list(zip(range(100), range(99, -1, -1)))
        it = self.pool.istarmap(mul, tuples, chunksize=20)
        for i, j in tuples:
            self.assertEqual(next(it), i * j)
        self.assertRaises(StopIteration, next, it)

    def test_istarmap_unordered(self):
        tuples = list(zip(range(10), range(9, -1, -1)))
        result = list(itertools.starmap(mul, tuples))
        it = self.pool.istarmap_unordered(mul, tuples)
        self.assertEqual(sorted(it), sorted(result))
        iterator = zip(range(10), range(9, -1, -1))
        it = self.pool.istarmap_unordered(mul, iterator)
        self.assertEqual(sorted(it), sorted(result))

        tuples = list(zip(range(100), range(99, -1, -1)))
        result = list(itertools.starmap(mul, tuples))
        it = self.pool.istarmap_unordered(mul, tuples, chunksize=20)
        self.assertEqual(sorted(it), sorted(result))

    def test_apply_async(self):
        res = self.pool.apply_async(sqr, (7,))
        self.assertEqual(res.get(), 49)

        res = self.pool.apply_async(sqr, (7, TIMEOUT2))
        get = TimingWrapper(res.get)
        self.assertEqual(get(), 49)
        self.assertLess(get.elapsed, TIMEOUT2 * 10)
        self.assertGreater(get.elapsed, TIMEOUT2 / 10)

    def test_apply_async_timeout(self):
        res = self.pool.apply_async(sqr, (7, TIMEOUT2))
        self.assertFalse(res.ready())
        self.assertRaises(ValueError, res.successful)
        res.wait(TIMEOUT2 / 100)
        self.assertFalse(res.ready())
        self.assertRaises(ValueError, res.successful)
        self.assertRaises(TimeoutError, res.get, TIMEOUT2 / 100)
        res.wait()
        self.assertTrue(res.ready())
        self.assertTrue(res.successful())
        self.assertEqual(res.get(), 49)

    def test_map_async(self):
        args = list(range(10))
        self.assertEqual(
            self.pool.map_async(sqr, args).get(), list(map(sqr, args))
        )
        args = list(range(100))
        self.assertEqual(
            self.pool.map_async(sqr, args, chunksize=20).get(),
            list(map(sqr, args)),
        )

    def test_map_async_callbacks(self):
        call_args = []
        result = self.pool.map_async(
            int,
            ["1", "2"],
            callback=call_args.append,
            error_callback=call_args.append,
        )
        result.wait()
        self.assertTrue(result.successful())
        self.assertEqual(len(call_args), 1)
        self.assertEqual(call_args[0], [1, 2])
        result = self.pool.map_async(
            int,
            ["a"],
            callback=call_args.append,
            error_callback=call_args.append,
        )
        result.wait()
        self.assertFalse(result.successful())
        self.assertEqual(len(call_args), 2)
        self.assertIsInstance(call_args[1], ValueError)

    def test_starmap_async(self):
        tuples = list(zip(range(10), range(9, -1, -1)))
        self.assertEqual(
            self.pool.starmap_async(mul, tuples).get(),
            list(itertools.starmap(mul, tuples)),
        )
        tuples = list(zip(range(1000), range(999, -1, -1)))
        self.assertEqual(
            self.pool.starmap_async(mul, tuples, chunksize=100).get(),
            list(itertools.starmap(mul, tuples)),
        )

    # ---

    def test_terminate(self):
        p = self.Pool(1)
        for _ in range(100):
            p.apply_async(time.sleep, (TIMEOUT1,))
        result = p.apply_async(time.sleep, (TIMEOUT1,))
        p.terminate()
        p.join()
        result.wait(TIMEOUT1)
        result.wait()
        self.assertFalse(result.successful())
        self.assertRaises(Exception, result.get, TIMEOUT1)
        self.assertRaises(Exception, result.get)

        p = self.Pool(1)
        args = [TIMEOUT1] * 100
        result = p.map_async(time.sleep, args, chunksize=1)
        p.terminate()
        p.join()
        result.wait(TIMEOUT1)
        result.wait()
        self.assertFalse(result.successful())
        self.assertRaises(Exception, result.get, TIMEOUT1)
        self.assertRaises(Exception, result.get)

    def test_empty_iterable(self):
        p = self.Pool(1)

        self.assertEqual(p.map(sqr, []), [])
        self.assertEqual(list(p.imap(sqr, [])), [])
        self.assertEqual(list(p.imap_unordered(sqr, [])), [])

        self.assertEqual(p.starmap(sqr, []), [])
        self.assertEqual(list(p.istarmap(sqr, [])), [])
        self.assertEqual(list(p.istarmap_unordered(sqr, [])), [])

        self.assertEqual(p.map_async(sqr, []).get(), [])
        self.assertEqual(p.starmap_async(mul, []).get(), [])

        p.close()
        p.join()

    def test_enter_exit(self):
        pool = self.Pool(1)

        with pool:
            pass
        # with self.assertRaises(ValueError):
        #     with pool:
        #         pass

        pool.join()

    # ---

    def test_async_error_callback(self):
        p = self.Pool(1)
        scratchpad = [None]

        def errback(exc):
            scratchpad[0] = exc

        res = p.apply_async(raising, error_callback=errback)
        p.close()
        p.join()
        self.assertRaises(KeyError, res.get)
        self.assertTrue(scratchpad[0])
        self.assertIsInstance(scratchpad[0], KeyError)

    def test_pool_worker_lifetime_early_close(self):
        p = self.Pool(1)
        results = []
        for i in range(5):
            results.append(p.apply_async(sqr, (i, TIMEOUT1)))
        p.close()
        p.join()
        for j, res in enumerate(results):
            self.assertEqual(res.get(), sqr(j))

    # ---

    def test_arg_processes(self):
        with self.assertRaises(ValueError):
            self.Pool(-1)
        with self.assertRaises(ValueError):
            self.Pool(0)

    def test_arg_initializer(self):
        self.Pool(1, initializer=identity, initargs=(123,))
        with self.assertRaises(TypeError):
            self.Pool(initializer=123)

    def test_unsupported_args(self):
        with warnings.catch_warnings():
            warnings.simplefilter("error")
            with self.assertRaises(UserWarning):
                with self.Pool(1, maxtasksperchild=1):
                    pass
            with self.assertRaises(UserWarning):
                with self.Pool(1, context=object):
                    pass


# ---


@unittest.skipIf(testutil.disable_mpi_spawn(), "mpi-spawn")
@unittest.skipIf(MPI.COMM_WORLD.Get_size() > 1, "mpi-world-size>1")
class TestProcessPool(BaseTestPool, unittest.TestCase):
    #
    PoolType = pool.Pool


class TestThreadPool(BaseTestPool, unittest.TestCase):
    #
    PoolType = pool.ThreadPool


# ---


class ExtraExecutorMixing:
    #
    def map(
        self,
        fn,
        iterable,
        timeout=None,
        chunksize=1,
        unordered=False,
    ):
        del unordered  # ignored, unused
        return super().map(
            fn,
            iterable,
            timeout=timeout,
            chunksize=chunksize,
        )

    def starmap(
        self,
        fn,
        iterable,
        timeout=None,
        chunksize=1,
        unordered=False,
    ):
        del unordered  # ignored, unused
        fn = functools.partial(self._apply_args, fn)
        return super().map(
            fn,
            iterable,
            timeout=timeout,
            chunksize=chunksize,
        )

    @staticmethod
    def _apply_args(fn, args):
        return fn(*args)


class ExtraExecutor(ExtraExecutorMixing, cf.ThreadPoolExecutor):
    #
    pass


class ExtraPool(pool.Pool):
    #
    Executor = ExtraExecutor


class TestExtraPool(BaseTestPool, unittest.TestCase):
    #
    PoolType = ExtraPool

    @classmethod
    def Pool(cls, *args, **kwargs):
        return cls.PoolType(*args, **kwargs)


del TestExtraPool


# ---

if __name__ == "__main__":
    unittest.main()
