# Copyright (c) 2023-2025 Arista Networks, Inc.
# Use of this source code is governed by the Apache License 2.0
# that can be found in the LICENSE file.
"""Utils for the ANTA benchmark tests."""

from __future__ import annotations

import asyncio
import copy
import importlib
import json
import pkgutil
from typing import TYPE_CHECKING, Any

import httpx

from anta.catalog import AntaCatalog, AntaTestDefinition
from anta.models import AntaCommand, AntaTest

if TYPE_CHECKING:
    from collections.abc import Generator
    from types import ModuleType

    from anta.device import AntaDevice


async def collect(self: AntaTest) -> None:
    """Patched anta.models.AntaTest.collect() method.

    When generating the catalog, we inject a unit test case name in the custom_field input to be able to retrieve the eos_data for this specific test.
    We use this unit test case name in the eAPI request ID.
    """
    if self.inputs.result_overwrite is None or self.inputs.result_overwrite.custom_field is None:
        msg = f"The custom_field input is not present for test {self.name}"
        raise RuntimeError(msg)
    await self.device.collect_commands(self.instance_commands, collection_id=f"{self.name}:{self.inputs.result_overwrite.custom_field}")


async def collect_commands(self: AntaDevice, commands: list[AntaCommand], collection_id: str) -> None:
    """Patched anta.device.AntaDevice.collect_commands() method.

    For the same reason as above, we inject the command index of the test to the eAPI request ID.
    """
    await asyncio.gather(*(self.collect(command=command, collection_id=f"{collection_id}:{idx}") for idx, command in enumerate(commands)))


class AntaMockEnvironment:  # pylint: disable=too-few-public-methods
    """Generate an ANTA test catalog from the unit tests data. It can be accessed using the `catalog` attribute of this class instance.

    Also provide the attribute 'eos_data_catalog` with the output of all the commands used in the test catalog.

    Each module in `tests.units.anta_tests` has a `DATA` constant.

    The `DATA` structure is a dictionary where:
        - Each key is a tuple of size 2 containing:
            - An AntaTest subclass imported in the test module as first element - e.g. VerifyUptime.
            - A string used as name displayed by pytest as second element.
        - Each value is an instance of AntaUnitTest, which is a Python TypedDict.

    And AntaUnitTest have the following keys:
        - `eos_data` (list[dict]): List of data mocking EOS returned data to be passed to the test.
        - `inputs` (dict): Dictionary to instantiate the `test` inputs as defined in the class from `test`.
        - `expected` (dict): Expected test result structure, a dictionary containing a key `result` containing one of the allowed status
        (`Literal[AntaTestStatus.SUCCESS, AntaTestStatus.FAILURE, AntaTestStatus.SKIPPED]`) and
        optionally a key `messages` which is a list(str) and each message is expected to be a substring of one of the actual messages in the TestResult object.

    The keys of `eos_data_catalog` is the tuple (AntaTest subclass, A string used as name displayed by pytest). The values are `eos_data`.
    """

    def __init__(self) -> None:
        self._catalog, self.eos_data_catalog = self._generate_catalog()
        self.tests_count = len(self._catalog.tests)

    @property
    def catalog(self) -> AntaCatalog:
        """AntaMockEnvironment object will always return a new AntaCatalog object based on the initial parsing.

        This is because AntaCatalog objects store indexes when tests are run and we want a new object each time a test is run.
        """
        return copy.deepcopy(self._catalog)

    def _generate_catalog(self) -> tuple[AntaCatalog, dict[tuple[str, str], list[dict[str, Any]]]]:
        """Generate the `catalog` and `eos_data_catalog` attributes."""

        def import_test_modules() -> Generator[ModuleType, None, None]:
            """Yield all test modules from the given package."""
            package = importlib.import_module("tests.units.anta_tests")
            prefix = package.__name__ + "."
            for _, module_name, is_pkg in pkgutil.walk_packages(package.__path__, prefix):
                if not is_pkg and module_name.split(".")[-1].startswith("test_"):
                    module = importlib.import_module(module_name)
                    if hasattr(module, "DATA"):
                        yield module

        test_definitions = []
        eos_data_catalog = {}
        for module in import_test_modules():
            for (test, name), test_data in module.DATA.items():
                # Extract the test class, name and test data from a nested tuple structure:
                # unit test: Tuple[Tuple[Type[AntaTest], str], AntaUnitTest]
                result_overwrite = AntaTest.Input.ResultOverwrite(custom_field=name)
                if test_data.get("inputs") is None:
                    inputs = test.Input(result_overwrite=result_overwrite)
                else:
                    inputs = test.Input(**test_data["inputs"], result_overwrite=result_overwrite)
                test_definition = AntaTestDefinition(
                    test=test,
                    inputs=inputs,
                )
                eos_data_catalog[(test.__name__, name)] = test_data["eos_data"]
                test_definitions.append(test_definition)

        return (AntaCatalog(tests=test_definitions), eos_data_catalog)

    def eapi_response(self, request: httpx.Request) -> httpx.Response:
        """Mock eAPI response.

        If the eAPI request ID has the format `ANTA-{test name}:{unit test name}:{command index}-{command ID}`,
        the function will return the eos_data from the unit test case.

        Otherwise, it will mock 'show version' command or raise an Exception.
        """
        words_count = 3

        def parse_req_id(req_id: str) -> tuple[str, str, int] | None:
            """Parse the patched request ID from the eAPI request."""
            req_id = req_id.removeprefix("ANTA-").rpartition("-")[0]
            words = req_id.split(":", words_count)
            if len(words) == words_count:
                test_name, unit_test_name, command_index = words
                return test_name, unit_test_name, int(command_index)
            return None

        jsonrpc = json.loads(request.content)
        assert jsonrpc["method"] == "runCmds"
        commands = jsonrpc["params"]["cmds"]
        ofmt = jsonrpc["params"]["format"]
        req_id: str = jsonrpc["id"]
        result = None

        # Extract the test name, unit test name, and command index from the request ID
        if (words := parse_req_id(req_id)) is not None:
            test_name, unit_test_name, idx = words

            # This should never happen, but better be safe than sorry
            if (test_name, unit_test_name) not in self.eos_data_catalog:
                msg = f"Error while generating a mock response for unit test {unit_test_name} of test {test_name}: eos_data not found"
                raise RuntimeError(msg)

            eos_data = self.eos_data_catalog[(test_name, unit_test_name)]

            # This could happen if the unit test data is not correctly defined
            if idx >= len(eos_data):
                msg = f"Error while generating a mock response for unit test {unit_test_name} of test {test_name}: missing test case in eos_data"
                raise RuntimeError(msg)
            result = {"output": eos_data[idx]} if ofmt == "text" else eos_data[idx]
        elif {"cmd": "show version"} in commands and ofmt == "json":
            # Mock 'show version' request performed during inventory refresh.
            result = {
                "modelName": "pytest",
            }

        if result is not None:
            return httpx.Response(
                status_code=200,
                json={
                    "jsonrpc": "2.0",
                    "id": req_id,
                    "result": [result],
                },
            )
        msg = f"The following eAPI Request has not been mocked: {jsonrpc}"
        raise NotImplementedError(msg)
