1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123
|
"""Test configuration."""
import asyncio
import functools
import pathlib
from collections.abc import AsyncGenerator
from collections.abc import Generator
from datetime import datetime
from datetime import timedelta
from datetime import tzinfo
from typing import Any
from typing import Optional
import pytest
import pytest_asyncio
from _pytest.monkeypatch import MonkeyPatch
from aiohttp.client import ClientSession
from aioresponses import aioresponses
from click.testing import CliRunner
@pytest_asyncio.fixture
async def websession() -> AsyncGenerator[ClientSession, None]:
    """Provide an aiohttp ClientSession that is fully torn down after the test.

    On teardown the session is closed explicitly. The fixture then waits on the
    event returned by ``create_aiohttp_closed_event``, which is set once the
    session's pooled SSL transports have reported ``connection_lost``.
    """
    async with ClientSession() as session:
        yield session

        # Arm the waiter before closing: it inspects the session's pooled
        # connections, which must still be reachable when it runs.
        transports_closed = create_aiohttp_closed_event(session)
        await session.close()
        await transports_closed.wait()
@pytest.fixture(autouse=True)
def mocked_responses() -> Generator[aioresponses, None, None]:
    """Patch aiohttp with ``aioresponses`` for the duration of every test.

    Because the fixture is ``autouse``, the patch is active even in tests that
    do not request it. Tests that need canned responses request it by name and
    register them on the yielded object.
    """
    with aioresponses() as mock_aiohttp:
        yield mock_aiohttp
@pytest.fixture
def cli_runner(monkeypatch: MonkeyPatch, tmpdir: pathlib.Path) -> CliRunner:
    """Return a click ``CliRunner`` isolated from the user's home and timezone.

    Two functions are monkeypatched for the duration of the test:

    * ``os.path.expanduser`` maps a leading ``~`` to the per-test ``tmpdir``, so
      paths resolved through it point into the temporary directory instead of
      the real home directory.
    * ``tzlocal.get_localzone`` returns a fixed UTC+01:00 zone with no DST.

    Args:
        monkeypatch: pytest's monkeypatch fixture, used to install the patches.
        tmpdir: per-test temporary directory used as the fake home.
            NOTE(review): pytest documents ``tmpdir`` as a legacy
            ``py.path.local`` rather than a ``pathlib.Path``. Only
            ``str(tmpdir)`` is used here, so either type works.

    Returns:
        A fresh ``CliRunner``.
    """
    runner = CliRunner()
    fake_home = str(tmpdir)

    def fake_expanduser(path: str) -> str:
        # Mirror the real os.path.expanduser: only a *leading* "~" is
        # expanded. Tildes elsewhere (e.g. "notes~" backup names, "/a~b")
        # must be left untouched instead of being replaced by the fake home.
        if path.startswith("~"):
            return fake_home + path[1:]
        return path

    monkeypatch.setattr("os.path.expanduser", fake_expanduser)

    class TZ1(tzinfo):
        """Fixed UTC+01:00 zone that never observes DST."""

        def utcoffset(self, dt: Optional[datetime]) -> timedelta:
            return timedelta(hours=1)

        def dst(self, dt: Optional[datetime]) -> timedelta:
            return timedelta(0)

        def tzname(self, dt: Optional[datetime]) -> str:
            return "+01:00"

        def __repr__(self) -> str:
            return f"{self.__class__.__name__}()"

    def get_test_zone() -> Any:
        # Get a non UTC zone, avoiding DST on standard zones.
        return TZ1()

    monkeypatch.setattr("tzlocal.get_localzone", get_test_zone)
    return runner
def create_aiohttp_closed_event(
    session: "ClientSession",
) -> asyncio.Event:
    """Work around aiohttp issue that doesn't properly close transports on exit.

    See https://github.com/aio-libs/aiohttp/issues/1925#issuecomment-639080209

    Every pooled connection whose transport exposes an ``_ssl_protocol`` has
    that protocol's ``connection_lost`` and ``eof_received`` wrapped:

    * ``connection_lost`` counts down the number of outstanding transports and
      sets the returned event when the count reaches zero.
    * ``eof_received`` tolerates being called after the protocol was torn down.

    Args:
        session (ClientSession): session for which to generate the event.

    Returns:
        An event that will be set once all transports have been properly
        closed. It is already set when there is nothing to wait for, i.e. the
        session has no connector or no SSL transports.
    """
    transports = 0
    all_is_lost = asyncio.Event()

    def connection_lost(exc: Optional[BaseException], orig_lost: Any) -> None:
        nonlocal transports
        try:
            orig_lost(exc)
        finally:
            # Count down even if the wrapped handler raised, so the event
            # cannot get stuck unset.
            transports -= 1
            if transports == 0:
                all_is_lost.set()

    def eof_received(orig_eof_received: Any) -> Optional[bool]:
        try:
            # Propagate the protocol's answer. Per the asyncio Protocol
            # contract, a true value keeps the transport open, while a false
            # value (or None) makes it close itself. Dropping the return value
            # would force a close the protocol did not ask for.
            return orig_eof_received()
        except AttributeError:
            # It may happen that eof_received() is called after
            # _app_protocol and _transport are set to None.
            return None

    connector = session.connector
    if connector is None:
        # No connector means no pooled transports to wait for.
        all_is_lost.set()
        return all_is_lost

    for conn in connector._conns.values():
        for handler, _ in conn:
            proto = getattr(handler.transport, "_ssl_protocol", None)
            if proto is None:
                # Transport without an _ssl_protocol (presumably non-SSL),
                # or no transport at all: nothing to wait on.
                continue
            transports += 1
            orig_lost = proto.connection_lost
            orig_eof_received = proto.eof_received

            proto.connection_lost = functools.partial(
                connection_lost, orig_lost=orig_lost
            )
            proto.eof_received = functools.partial(
                eof_received, orig_eof_received=orig_eof_received
            )

    if transports == 0:
        all_is_lost.set()

    return all_is_lost
|