1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149
|
"""Common helpers for tests."""
from __future__ import annotations
import json
from dataclasses import dataclass
from pathlib import Path
from typing import Any
from aiohttp import WSMsgType
class WSMessage:
    """Test stand-in for a websocket message with a ``type`` and a ``json()`` payload."""

    def __init__(self, messagetype: WSMsgType, json: dict | None = None) -> None:
        """Store the message type and the optional payload.

        Args:
            messagetype: Value exposed as the ``type`` attribute.
            json: Payload handed back by :meth:`json`; ``None`` when omitted.
        """
        self._json: dict | None = json
        self.type = messagetype

    def json(self) -> dict | None:
        """Return the payload given at construction (``None`` if none was given)."""
        return self._json
def load_response(filename: str) -> dict[str, Any]:
    """Load a canned JSON response from the ``responses`` fixture directory.

    The fixture file name is derived from *filename* as follows:

    * any query string (``?...``) is dropped;
    * ``.json`` is appended when the remaining name contains no ``.``;
    * the result is lower-cased and every ``/`` is replaced with ``_``.

    The ``responses`` directory is resolved relative to this module's file.

    Args:
        filename: Endpoint path or fixture file name, e.g. ``"api/status?x=1"``.

    Returns:
        The decoded JSON content of the fixture file.

    Raises:
        FileNotFoundError: If no matching fixture file exists.
        json.JSONDecodeError: If the fixture does not contain valid JSON.
    """
    # Strip the query string *before* the extension check, so a dot inside a
    # query value (e.g. ``?ip=1.2.3.4``) is not mistaken for a file extension.
    name = filename.split("?", 1)[0]
    if "." not in name:
        name = f"{name}.json"
    path = (
        Path(__file__).resolve().parent
        / "responses"
        / name.lower().replace("/", "_")
    )
    with path.open(encoding="utf-8") as fptr:
        return json.load(fptr)
class WSMessageHandler:
    """First-in, first-out queue of canned websocket messages."""

    def __init__(self) -> None:
        """Create the handler with an empty ``messages`` list."""
        self.messages: list[WSMessage] = []

    def add(self, msg: WSMessage) -> None:
        """Enqueue *msg* at the back of ``messages``."""
        self.messages.append(msg)

    def get(self) -> WSMessage:
        """Dequeue the oldest message.

        When the queue is empty, a fresh ``WSMessage`` of type
        ``WSMsgType.CLOSED`` is returned instead.
        """
        if not self.messages:
            return WSMessage(messagetype=WSMsgType.CLOSED)
        return self.messages.pop(0)
@dataclass
class MockResponse:
    """Test stand-in for an HTTP client response, usable with ``async with``.

    Payload precedence in :meth:`json`:

    1. ``mock_raises`` is raised when set;
    2. otherwise successive items of ``mock_data_list`` are returned, one per call;
    3. otherwise ``mock_data`` is returned when not ``None``;
    4. otherwise the fixture for ``mock_endpoint`` is read via ``load_response``.
    """

    # Un-annotated on purpose: a plain class attribute, not a dataclass field.
    # json() shadows it with a per-instance cursor into mock_data_list.
    _count = 0
    mock_data: Any | None = None
    mock_data_list: list[Any] | None = None
    mock_endpoint: str = ""
    mock_headers: dict[str, str] | None = None
    mock_raises: BaseException | None = None
    mock_status: int = 200
    _in_context: bool = False

    async def __aenter__(self) -> MockResponse:
        """Enter ``async with``: flag the response as in use and return it."""
        self._in_context = True
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb) -> None:
        """Leave ``async with``: clear the flag, then release and wait for close."""
        self._in_context = False
        self.release()
        await self.wait_for_close()

    @property
    def status(self) -> int:
        """Mocked HTTP status code (``mock_status``)."""
        return self.mock_status

    @property
    def reason(self) -> str:
        """Return the reason phrase; always ``"unknown"`` for this mock."""
        return "unknown"

    async def json(self, **_: Any) -> Any:
        """Return the mocked JSON payload (see the class docstring for precedence).

        Any keyword arguments are accepted and ignored.

        Raises:
            BaseException: ``mock_raises``, when it is set.
            IndexError: When every item of ``mock_data_list`` has been consumed.
        """
        if self.mock_raises is not None:
            raise self.mock_raises  # pylint: disable=raising-bad-type
        if self.mock_data_list:
            data = self.mock_data_list[self._count]
            self._count += 1
            return data
        if self.mock_data is not None:
            return self.mock_data
        return load_response(self.mock_endpoint)

    def release(self) -> None:
        """No-op; present so ``__aexit__`` can call it."""

    def clear(self) -> None:
        """Reset every mock setting, including the queued data list, to its default."""
        self.mock_data = None
        # Also drop the queued list and rewind its cursor. Otherwise json()
        # keeps serving stale items, and a newly assigned list would be read
        # from the old offset (possibly raising IndexError).
        self.mock_data_list = None
        self._count = 0
        self.mock_endpoint = ""
        self.mock_headers = None
        self.mock_raises = None
        self.mock_status = 200

    async def wait_for_close(self) -> None:
        """No-op coroutine; awaited by ``__aexit__``."""
class MockedRequests:
    """Record of the request URLs seen during a test."""

    def __init__(self) -> None:
        """Start with no recorded requests."""
        self._calls: list[str] = []

    def add(self, url: str) -> None:
        """Record a request to *url*."""
        self._calls.append(url)

    def clear(self) -> None:
        """Forget all recorded requests."""
        self._calls.clear()

    def __repr__(self) -> str:
        """Return a debug representation listing the recorded URLs."""
        return f"<MockedRequests: {self._calls}>"

    @property
    def called(self) -> int:
        """Number of requests recorded so far."""
        return len(self._calls)

    def has(self, string: str) -> bool:
        """Return True if any recorded URL contains *string* as a substring."""
        # any() short-circuits instead of materialising a throwaway list.
        return any(string in entry for entry in self._calls)

    @property
    def last_request(self) -> str:
        """The most recently recorded URL.

        Raises:
            IndexError: If no request has been recorded yet.
        """
        return self._calls[-1]
|