import asyncio
import threading
import unittest
from threading import Thread
from unittest import TestCase
import weakref
from test import support
from test.support import threading_helper

# Every test in this module starts threads, so skip the whole module when
# the platform cannot run them.
threading_helper.requires_working_threading(module=True)


class MyException(Exception):
    """Marker exception raised by test coroutines.

    Using a dedicated type lets the tests verify that this exact exception,
    rather than some unrelated failure, is what reaches the caller.
    """


def tearDownModule():
    # Module-level unittest fixture: clear the event loop policy once all
    # tests in this module have run.
    reset_policy = asyncio.events._set_event_loop_policy
    reset_policy(None)


class TestFreeThreading:
    """Thread-safety tests shared by the asyncio implementation variants.

    Mixin: concrete subclasses combine it with ``unittest.TestCase`` and
    provide ``factory(loop, coro, **kwargs)``, which most tests install as
    the event loop's task factory.
    """

    def test_all_tasks_race(self) -> None:
        # Ten threads each run their own event loop at the same time; every
        # loop must see exactly its own tasks in all_tasks().
        async def main():
            loop = asyncio.get_running_loop()
            future = loop.create_future()

            async def coro():
                await future

            tasks = set()

            async with asyncio.TaskGroup() as tg:
                for _ in range(100):
                    tasks.add(tg.create_task(coro()))

                # 100 child tasks plus the task running main() itself.
                all_tasks = asyncio.all_tasks(loop)
                self.assertEqual(len(all_tasks), 101)

                for task in all_tasks:
                    self.assertEqual(task.get_loop(), loop)
                    self.assertFalse(task.done())

                current = asyncio.current_task()
                self.assertEqual(current.get_loop(), loop)
                self.assertSetEqual(all_tasks, tasks | {current})
                # Unblock every child so the TaskGroup can exit.
                future.set_result(None)

        def runner():
            with asyncio.Runner() as runner:
                loop = runner.get_loop()
                loop.set_task_factory(self.factory)
                runner.run(main())

        threads = []

        for _ in range(10):
            thread = Thread(target=runner)
            threads.append(thread)

        # start_threads() starts the threads on entry and joins them on exit.
        with threading_helper.start_threads(threads):
            pass

    def test_all_tasks_different_thread(self) -> None:
        # all_tasks(loop) called from other threads must contain every task
        # already recorded by the thread that runs the loop.
        loop = None
        started = threading.Event()
        done = threading.Event()  # used for main task not finishing early

        async def coro():
            await asyncio.Future()

        lock = threading.Lock()
        tasks = set()

        async def main():
            nonlocal tasks, loop
            loop = asyncio.get_running_loop()
            started.set()
            for _ in range(1000):
                # The lock makes "create a task + snapshot all_tasks()" one
                # step as far as the checker threads are concerned.
                with lock:
                    asyncio.create_task(coro())
                    tasks = asyncio.all_tasks(loop)
            # Deliberately blocks the loop's thread so the loop and its tasks
            # stay alive until the checker threads are finished.
            done.wait()

        runner = threading.Thread(target=lambda: asyncio.run(main()))

        def check():
            started.wait()
            with lock:
                # Equivalent to: tasks is a subset of all_tasks(loop).
                self.assertSetEqual(tasks & asyncio.all_tasks(loop), tasks)

        threads = [threading.Thread(target=check) for _ in range(10)]
        runner.start()

        try:
            with threading_helper.start_threads(threads):
                pass
        finally:
            # Always release main(): if start_threads() raises, the runner
            # thread would otherwise stay blocked in done.wait() forever.
            done.set()
            runner.join()

    def test_task_different_thread_finalized(self) -> None:
        task = None

        async def func():
            nonlocal task
            task = asyncio.current_task()

        def runner():
            with asyncio.Runner() as runner:
                loop = runner.get_loop()
                loop.set_task_factory(self.factory)
                runner.run(func())

        thread = Thread(target=runner)
        thread.start()
        thread.join()
        # The task was created in the (now finished) worker thread; keep only
        # a weak reference so dropping `task` below can finalize it here.
        wr = weakref.ref(task)
        del thread
        del task
        # task finalization in different thread shouldn't crash
        support.gc_collect()
        self.assertIsNone(wr())

    def test_run_coroutine_threadsafe(self) -> None:
        results = []

        def in_thread(loop: asyncio.AbstractEventLoop):
            # Runs in a worker thread: schedule a coroutine on the loop and
            # block on the returned future for its result.
            coro = asyncio.sleep(0.1, result=42)
            fut = asyncio.run_coroutine_threadsafe(coro, loop)
            result = fut.result()
            self.assertEqual(result, 42)
            results.append(result)

        async def main():
            loop = asyncio.get_running_loop()
            async with asyncio.TaskGroup() as tg:
                for _ in range(10):
                    tg.create_task(asyncio.to_thread(in_thread, loop))
            self.assertEqual(results, [42] * 10)

        with asyncio.Runner() as r:
            loop = r.get_loop()
            loop.set_task_factory(self.factory)
            r.run(main())

    def test_run_coroutine_threadsafe_exception(self) -> None:
        async def coro():
            await asyncio.sleep(0)
            raise MyException("test")

        def in_thread(loop: asyncio.AbstractEventLoop):
            # fut.result() re-raises the coroutine's exception in this worker
            # thread; to_thread() then hands it back to the awaiting task.
            fut = asyncio.run_coroutine_threadsafe(coro(), loop)
            return fut.result()

        async def main():
            loop = asyncio.get_running_loop()
            tasks = []
            for _ in range(10):
                task = loop.create_task(asyncio.to_thread(in_thread, loop))
                tasks.append(task)
            # return_exceptions=True collects the exceptions as results
            # instead of raising the first one.
            results = await asyncio.gather(*tasks, return_exceptions=True)

            self.assertEqual(len(results), 10)
            for result in results:
                self.assertIsInstance(result, MyException)
                self.assertEqual(str(result), "test")

        with asyncio.Runner() as r:
            loop = r.get_loop()
            loop.set_task_factory(self.factory)
            r.run(main())


class TestPyFreeThreading(TestFreeThreading, TestCase):
    """Run the shared tests against the pure-Python asyncio implementation."""

    def setUp(self):
        # Point the public asyncio names at the pure-Python variants,
        # remembering the originals so tearDown() can put them back.
        self._old_current_task = asyncio.current_task
        asyncio.current_task = asyncio.tasks.current_task = asyncio.tasks._py_current_task
        self._old_all_tasks = asyncio.all_tasks
        asyncio.all_tasks = asyncio.tasks.all_tasks = asyncio.tasks._py_all_tasks
        self._old_Task = asyncio.Task
        asyncio.Task = asyncio.tasks.Task = asyncio.tasks._PyTask
        self._old_Future = asyncio.Future
        asyncio.Future = asyncio.futures.Future = asyncio.futures._PyFuture
        return super().setUp()

    def tearDown(self):
        # Undo every patch applied in setUp(), attribute for attribute.
        asyncio.current_task = asyncio.tasks.current_task = self._old_current_task
        asyncio.all_tasks = asyncio.tasks.all_tasks = self._old_all_tasks
        asyncio.Task = asyncio.tasks.Task = self._old_Task
        # setUp() patched asyncio.futures.Future, so that is the attribute
        # that has to be restored here.
        asyncio.Future = asyncio.futures.Future = self._old_Future
        return super().tearDown()

    def factory(self, loop, coro, **kwargs):
        # Task factory installed on the event loop by the tests.
        return asyncio.tasks._PyTask(coro, loop=loop, **kwargs)


@unittest.skipUnless(hasattr(asyncio.tasks, "_c_all_tasks"), "requires _asyncio")
class TestCFreeThreading(TestFreeThreading, TestCase):
    """Run the shared tests against the C-accelerated asyncio implementation."""

    def setUp(self):
        # Save each public name, then point it (on the package and on its
        # defining submodule) at the C variant.
        self._old_current_task = asyncio.current_task
        c_current_task = asyncio.tasks._c_current_task
        asyncio.current_task = c_current_task
        asyncio.tasks.current_task = c_current_task

        self._old_all_tasks = asyncio.all_tasks
        c_all_tasks = asyncio.tasks._c_all_tasks
        asyncio.all_tasks = c_all_tasks
        asyncio.tasks.all_tasks = c_all_tasks

        self._old_Task = asyncio.Task
        c_task = asyncio.tasks._CTask
        asyncio.Task = c_task
        asyncio.tasks.Task = c_task

        self._old_Future = asyncio.Future
        c_future = asyncio.futures._CFuture
        asyncio.Future = c_future
        asyncio.futures.Future = c_future

        return super().setUp()

    def tearDown(self):
        # Put back everything setUp() replaced.
        saved_current_task = self._old_current_task
        asyncio.current_task = saved_current_task
        asyncio.tasks.current_task = saved_current_task

        saved_all_tasks = self._old_all_tasks
        asyncio.all_tasks = saved_all_tasks
        asyncio.tasks.all_tasks = saved_all_tasks

        saved_task = self._old_Task
        asyncio.Task = saved_task
        asyncio.tasks.Task = saved_task

        saved_future = self._old_Future
        asyncio.Future = saved_future
        asyncio.futures.Future = saved_future

        return super().tearDown()

    def factory(self, loop, coro, **kwargs):
        # Task factory installed on the event loop by the tests.
        task_cls = asyncio.tasks._CTask
        return task_cls(coro, loop=loop, **kwargs)


class TestEagerPyFreeThreading(TestPyFreeThreading):
    """Pure-Python variant whose task factory defaults to eager start."""

    def factory(self, loop, coro, eager_start=True, **kwargs):
        # Same as the parent factory, except eager_start defaults to True.
        task_cls = asyncio.tasks._PyTask
        return task_cls(coro, loop=loop, eager_start=eager_start, **kwargs)


@unittest.skipUnless(hasattr(asyncio.tasks, "_c_all_tasks"), "requires _asyncio")
class TestEagerCFreeThreading(TestCFreeThreading):
    """C-accelerated variant whose task factory defaults to eager start.

    ``TestCase`` is inherited through ``TestCFreeThreading``, so it is not
    listed again as a base.
    """

    def factory(self, loop, coro, eager_start=True, **kwargs):
        # Same as the parent factory, except eager_start defaults to True.
        return asyncio.tasks._CTask(coro, loop=loop, **kwargs, eager_start=eager_start)
