# Copyright 2021-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Test the Load Balancer unified spec tests."""
from __future__ import annotations

import asyncio
import gc
import os
import pathlib
import sys
import threading
from asyncio import Event
from test.asynchronous.helpers import ConcurrentRunner, ExceptionCatchingTask
from test.asynchronous.utils import async_get_pool

import pytest

sys.path[0:0] = [""]

from test.asynchronous import AsyncIntegrationTest, async_client_context, unittest
from test.asynchronous.unified_format import generate_test_classes
from test.utils_shared import (
    async_wait_until,
    create_async_event,
)

from pymongo.asynchronous.helpers import anext

_IS_SYNC = False

# Apply the load_balancer marker to every test collected from this module.
pytestmark = pytest.mark.load_balancer

# Location of JSON test specifications: a "load_balancer" directory beside this
# file when _IS_SYNC is true, otherwise one directory level further up.
_TEST_PATH = os.path.join(
    (
        pathlib.Path(__file__).resolve().parent
        if _IS_SYNC
        else pathlib.Path(__file__).resolve().parent.parent
    ),
    "load_balancer",
)

# Generate unified tests and expose the generated classes as module globals so
# that test discovery can find them.
globals().update(generate_test_classes(_TEST_PATH, module=__name__))


class TestLB(AsyncIntegrationTest):
    """Connection pool checks: connection counts, pinning, and GC safety.

    The tests inspect ``pool.conns`` and ``pool.active_sockets`` around
    ordinary operations, transactions, and garbage collection of cursors and
    sessions while another task holds the pool lock.
    """

    # NOTE(review): presumably consulted by the test harness to let this class
    # run against a load balancer -- confirm in AsyncIntegrationTest.
    RUN_ON_LOAD_BALANCER = True

    async def test_connections_are_only_returned_once(self):
        """``len(pool.conns)`` is unchanged after a find and an aggregate."""
        if "PyPy" in sys.version:
            # Tracked in PYTHON-3011
            self.skipTest("Test is flaky on PyPy")
        pool = await async_get_pool(self.client)
        baseline = len(pool.conns)

        await self.db.test.find_one({})
        self.assertEqual(len(pool.conns), baseline)

        # The aggregate cursor is deliberately never bound to a local name.
        await (await self.db.test.aggregate([{"$limit": 1}])).to_list()
        self.assertEqual(len(pool.conns), baseline)

    @async_client_context.require_load_balancer
    async def test_unpin_committed_transaction(self):
        """The socket stays checked out until the session ends, not the txn."""
        client = await self.async_rs_client()
        pool = await async_get_pool(client)
        collection = client[self.db.name].test
        async with client.start_session() as txn_session:
            async with await txn_session.start_transaction():
                # Nothing is checked out until the first operation runs.
                self.assertEqual(pool.active_sockets, 0)
                await collection.insert_one({}, session=txn_session)
                self.assertEqual(pool.active_sockets, 1)  # Pinned.
            # The transaction block has exited but the session is still open.
            self.assertEqual(pool.active_sockets, 1)  # Still pinned.
        self.assertEqual(pool.active_sockets, 0)  # Unpinned.

    @async_client_context.require_failCommand_fail_point
    async def test_cursor_gc(self):
        """Collecting a started ``find`` cursor must not deadlock the pool."""

        async def open_find_cursor(coll):
            # batch_size=3 is smaller than the 10 documents inserted by
            # _test_no_gc_deadlock, so one anext() does not drain the cursor.
            find_cursor = coll.find({}, batch_size=3)
            await anext(find_cursor)
            return find_cursor

        await self._test_no_gc_deadlock(open_find_cursor)

    @async_client_context.require_failCommand_fail_point
    async def test_command_cursor_gc(self):
        """Collecting a started ``aggregate`` cursor must not deadlock the pool."""

        async def open_aggregate_cursor(coll):
            # Same shape as the find variant, but through aggregate().
            agg_cursor = await coll.aggregate([], batchSize=3)
            await anext(agg_cursor)
            return agg_cursor

        await self._test_no_gc_deadlock(open_aggregate_cursor)

    async def _test_no_gc_deadlock(self, create_resource):
        """Garbage collect a resource while the pool lock is held elsewhere.

        ``create_resource`` is an async callable that receives the test
        collection and returns the object to be collected. The object is
        created under a one-shot fail point, its only local reference is
        deleted while a ``PoolLocker`` task holds ``pool.lock``, and the test
        then checks that the locker finished without an exception and that
        ``pool.active_sockets`` drops back to zero.
        """
        client = await self.async_rs_client()
        pool = await async_get_pool(client)
        coll = client[self.db.name].test
        await coll.insert_many([{} for _ in range(10)])
        self.assertEqual(pool.active_sockets, 0)

        # Cause the initial find attempt to fail to induce a reference cycle.
        fail_data = {
            "failCommands": ["find", "aggregate"],
            "closeConnection": True,
        }
        fail_once = {"mode": {"times": 1}, "data": fail_data}
        async with self.fail_point(fail_once):
            resource = await create_resource(coll)
            if async_client_context.load_balancer:
                self.assertEqual(pool.active_sockets, 1)  # Pinned.

        locker = PoolLocker(pool)
        await locker.start()
        self.assertTrue(await locker.wait(locker.locked, 5), "timed out")
        # Garbage collect the resource while the pool is locked to ensure we
        # don't deadlock.
        del resource
        # On PyPy it can take a few rounds to collect the cursor.
        for _gc_round in range(3):
            gc.collect()
        locker.unlock.set()
        await locker.join(5)
        self.assertFalse(locker.is_alive())
        self.assertIsNone(locker.exc)

        def socket_returned():
            return pool.active_sockets == 0

        await async_wait_until(socket_returned, "return socket")
        # Run another operation to ensure the socket still works.
        await coll.delete_many({})

    @async_client_context.require_transactions
    async def test_session_gc(self):
        """Collecting a session with an open transaction must not deadlock."""
        client = await self.async_rs_client()
        pool = await async_get_pool(client)
        # The session is intentionally never ended; it is dropped and garbage
        # collected further down.
        session = client.start_session()
        await session.start_transaction()
        await client.test_session_gc.test.find_one({}, session=session)
        # Cleanup the transaction left open on the server
        self.addAsyncCleanup(self.client.admin.command, "killSessions", [session.session_id])
        if async_client_context.load_balancer:
            self.assertEqual(pool.active_sockets, 1)  # Pinned.

        locker = PoolLocker(pool)
        await locker.start()
        self.assertTrue(await locker.wait(locker.locked, 5), "timed out")
        # Garbage collect the session while the pool is locked to ensure we
        # don't deadlock.
        del session
        # On PyPy it can take a few rounds to collect the session.
        for _gc_round in range(3):
            gc.collect()
        locker.unlock.set()
        await locker.join(5)
        self.assertFalse(locker.is_alive())
        self.assertIsNone(locker.exc)

        def socket_returned():
            return pool.active_sockets == 0

        await async_wait_until(socket_returned, "return socket")
        # Run another operation to ensure the socket still works.
        await client[self.db.name].test.delete_many({})


class PoolLocker(ExceptionCatchingTask):
    """Task that takes ``pool.lock`` and holds it until told to let go.

    ``locked`` is set as soon as the lock has been acquired; the task then
    waits up to 10 seconds for ``unlock`` to be set and raises if it never is.
    """

    def __init__(self, pool):
        super().__init__(target=self.lock_pool)
        self.daemon = True
        self.pool = pool
        # Set by lock_pool() once pool.lock is held.
        self.locked = create_async_event()
        # Set by the caller to allow lock_pool() to release pool.lock.
        self.unlock = create_async_event()

    async def lock_pool(self):
        """Hold ``self.pool.lock`` until ``self.unlock`` is set.

        Raises:
            Exception: if ``unlock`` is not set within 10 seconds.
        """
        async with self.pool.lock:
            self.locked.set()
            # Wait for the unlock flag.
            released = await self.wait(self.unlock, 10)
            if released:
                return
            raise Exception("timed out waiting for unlock signal: deadlock?")

    async def wait(self, event: Event, timeout: int):
        """Wait for ``event`` for at most ``timeout`` seconds.

        When ``_IS_SYNC`` is true the result of ``event.wait(timeout)`` is
        returned directly. Otherwise returns ``True`` if the event was set in
        time and ``False`` if the wait timed out.
        """
        if _IS_SYNC:
            return event.wait(timeout)  # type: ignore[call-arg]
        try:
            await asyncio.wait_for(event.wait(), timeout=timeout)
        except asyncio.TimeoutError:
            return False
        else:
            return True


if __name__ == "__main__":
    unittest.main()
