1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173
|
# Copyright (c) 2023-2025 Arista Networks, Inc.
# Use of this source code is governed by the Apache License 2.0
# that can be found in the LICENSE file.
"""Utils for the ANTA benchmark tests."""
from __future__ import annotations
import asyncio
import copy
import importlib
import json
import pkgutil
from typing import TYPE_CHECKING, Any
import httpx
from anta.catalog import AntaCatalog, AntaTestDefinition
from anta.models import AntaCommand, AntaTest
if TYPE_CHECKING:
from collections.abc import Generator
from types import ModuleType
from anta.device import AntaDevice
async def collect(self: AntaTest) -> None:
    """Stand-in for `anta.models.AntaTest.collect()` used when benchmarking.

    When the catalog is generated, the unit test case name is stored in the `custom_field` input.
    It is forwarded here as part of the collection ID (`{test name}:{unit test case name}`).
    That ID ends up in the eAPI request ID, which lets the mock find the matching eos_data.

    Raises
    ------
    RuntimeError
        If `result_overwrite` is unset or its `custom_field` is None.
    """
    overwrite = self.inputs.result_overwrite
    unit_test_name = None if overwrite is None else overwrite.custom_field
    if unit_test_name is None:
        msg = f"The custom_field input is not present for test {self.name}"
        raise RuntimeError(msg)

    collection_id = f"{self.name}:{unit_test_name}"
    await self.device.collect_commands(self.instance_commands, collection_id=collection_id)
async def collect_commands(self: AntaDevice, commands: list[AntaCommand], collection_id: str) -> None:
    """Stand-in for `anta.device.AntaDevice.collect_commands()` used when benchmarking.

    Each command is collected concurrently with `asyncio.gather`.
    Its position in `commands` is appended to the collection ID (`{collection_id}:{index}`).
    This lets the eAPI mock pick the matching entry from the unit test eos_data.
    """
    coroutines = [
        self.collect(command=cmd, collection_id=f"{collection_id}:{position}")
        for position, cmd in enumerate(commands)
    ]
    await asyncio.gather(*coroutines)
class AntaMockEnvironment:  # pylint: disable=too-few-public-methods
    """Generate an ANTA test catalog from the unit tests data. It can be accessed using the `catalog` attribute of this class instance.

    Also provide the attribute `eos_data_catalog` with the output of all the commands used in the test catalog.

    Each module in `tests.units.anta_tests` has a `DATA` constant.
    The `DATA` structure is a dictionary where:
    - Each key is a tuple of size 2 containing:
        - An AntaTest subclass imported in the test module as first element - e.g. VerifyUptime.
        - A string used as name displayed by pytest as second element.
    - Each value is an instance of AntaUnitTest, which is a Python TypedDict.

    And AntaUnitTest have the following keys:
    - `eos_data` (list[dict]): List of data mocking EOS returned data to be passed to the test.
    - `inputs` (dict): Dictionary to instantiate the `test` inputs as defined in the class from `test`.
    - `expected` (dict): Expected test result structure, a dictionary containing a key `result` containing one of the allowed status
    (`Literal[AntaTestStatus.SUCCESS, AntaTestStatus.FAILURE, AntaTestStatus.SKIPPED]`) and
    optionally a key `messages` which is a list(str) and each message is expected to be a substring of one of the actual messages in the TestResult object.

    The keys of `eos_data_catalog` are tuples (name of the AntaTest subclass, string used as name displayed by pytest).
    The values are the corresponding `eos_data` lists.
    """

    def __init__(self) -> None:
        """Parse the unit tests data once and store the resulting catalog and eos_data mapping."""
        self._catalog, self.eos_data_catalog = self._generate_catalog()
        self.tests_count = len(self._catalog.tests)

    @property
    def catalog(self) -> AntaCatalog:
        """AntaMockEnvironment object will always return a new AntaCatalog object based on the initial parsing.

        This is because AntaCatalog objects store indexes when tests are run and we want a new object each time a test is run.
        """
        return copy.deepcopy(self._catalog)

    def _generate_catalog(self) -> tuple[AntaCatalog, dict[tuple[str, str], list[dict[str, Any]]]]:
        """Generate the `catalog` and `eos_data_catalog` attributes.

        Returns
        -------
        tuple
            The AntaCatalog built from every unit test case, and the mapping
            `(AntaTest subclass name, unit test name) -> eos_data`.
        """

        def import_test_modules() -> Generator[ModuleType, None, None]:
            """Yield every `test_*` module of the `tests.units.anta_tests` package that defines a `DATA` constant."""
            package = importlib.import_module("tests.units.anta_tests")
            prefix = package.__name__ + "."
            for _, module_name, is_pkg in pkgutil.walk_packages(package.__path__, prefix):
                if not is_pkg and module_name.split(".")[-1].startswith("test_"):
                    module = importlib.import_module(module_name)
                    if hasattr(module, "DATA"):
                        yield module

        test_definitions: list[AntaTestDefinition] = []
        eos_data_catalog: dict[tuple[str, str], list[dict[str, Any]]] = {}
        for module in import_test_modules():
            # Each DATA entry is: Tuple[Tuple[Type[AntaTest], str], AntaUnitTest]
            for (test, name), test_data in module.DATA.items():
                # Store the unit test name in `custom_field` so the patched `collect()` can embed it in the eAPI request ID.
                result_overwrite = AntaTest.Input.ResultOverwrite(custom_field=name)
                if test_data.get("inputs") is None:
                    inputs = test.Input(result_overwrite=result_overwrite)
                else:
                    inputs = test.Input(**test_data["inputs"], result_overwrite=result_overwrite)
                test_definition = AntaTestDefinition(
                    test=test,
                    inputs=inputs,
                )
                # Keyed by the class name (a string): `eapi_response` looks entries up with the test name parsed from the request ID.
                eos_data_catalog[(test.__name__, name)] = test_data["eos_data"]
                test_definitions.append(test_definition)
        return (AntaCatalog(tests=test_definitions), eos_data_catalog)

    @staticmethod
    def _parse_req_id(req_id: str) -> tuple[str, str, int] | None:
        """Parse the patched request ID from the eAPI request.

        The expected format is `ANTA-{test name}:{unit test name}:{command index}-{command ID}`.

        The test name is taken up to the first `:` and the command index after the last `:`.
        This means unit test names may themselves contain `:` (e.g. an IP address).

        Returns
        -------
        tuple[str, str, int] | None
            `(test name, unit test name, command index)`, or None when `req_id` does not follow the format.
        """
        # Drop the `ANTA-` prefix and the trailing `-{command ID}`.
        collection_id = req_id.removeprefix("ANTA-").rpartition("-")[0]
        test_name, test_sep, remainder = collection_id.partition(":")
        unit_test_name, index_sep, command_index = remainder.rpartition(":")
        # A missing separator or non-numeric index means this is not a patched unit test request.
        if not (test_sep and index_sep and command_index.isdecimal()):
            return None
        return test_name, unit_test_name, int(command_index)

    def eapi_response(self, request: httpx.Request) -> httpx.Response:
        """Mock eAPI response.

        If the eAPI request ID has the format `ANTA-{test name}:{unit test name}:{command index}-{command ID}`,
        the function will return the eos_data from the unit test case.

        Otherwise, it will mock 'show version' command or raise an Exception.

        Raises
        ------
        RuntimeError
            If the request targets a unit test case with no eos_data, or a command index outside its eos_data.
        NotImplementedError
            If the request matches neither a unit test case nor a JSON 'show version' command.
        """
        jsonrpc = json.loads(request.content)
        assert jsonrpc["method"] == "runCmds"
        commands = jsonrpc["params"]["cmds"]
        ofmt = jsonrpc["params"]["format"]
        req_id: str = jsonrpc["id"]
        result = None

        # Extract the test name, unit test name, and command index from the request ID
        if (words := self._parse_req_id(req_id)) is not None:
            test_name, unit_test_name, idx = words

            # This should never happen, but better be safe than sorry
            if (test_name, unit_test_name) not in self.eos_data_catalog:
                msg = f"Error while generating a mock response for unit test {unit_test_name} of test {test_name}: eos_data not found"
                raise RuntimeError(msg)

            eos_data = self.eos_data_catalog[(test_name, unit_test_name)]

            # This could happen if the unit test data is not correctly defined
            if idx >= len(eos_data):
                msg = f"Error while generating a mock response for unit test {unit_test_name} of test {test_name}: missing test case in eos_data"
                raise RuntimeError(msg)
            # Text-format eAPI responses wrap the command output in an `output` key.
            result = {"output": eos_data[idx]} if ofmt == "text" else eos_data[idx]
        elif {"cmd": "show version"} in commands and ofmt == "json":
            # Mock 'show version' request performed during inventory refresh.
            result = {
                "modelName": "pytest",
            }

        if result is not None:
            return httpx.Response(
                status_code=200,
                json={
                    "jsonrpc": "2.0",
                    "id": req_id,
                    "result": [result],
                },
            )

        msg = f"The following eAPI Request has not been mocked: {jsonrpc}"
        raise NotImplementedError(msg)
|