File: conftest.py

package info (click to toggle)
python-renault-api 0.4.1-1
  • links: PTS, VCS
  • area: main
  • in suites: forky
  • size: 1,680 kB
  • sloc: python: 7,679; makefile: 2
file content (123 lines) | stat: -rw-r--r-- 3,618 bytes parent folder | download | duplicates (3)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
"""Test configuration."""

import asyncio
import functools
import pathlib
from collections.abc import AsyncGenerator
from collections.abc import Generator
from datetime import datetime
from datetime import timedelta
from datetime import tzinfo
from typing import Any
from typing import Optional

import pytest
import pytest_asyncio
from _pytest.monkeypatch import MonkeyPatch
from aiohttp.client import ClientSession
from aioresponses import aioresponses
from click.testing import CliRunner


@pytest_asyncio.fixture
async def websession() -> AsyncGenerator[ClientSession, None]:
    """Yield an aiohttp ``ClientSession`` and tear it down deterministically.

    Once the test is done, the session is closed explicitly and teardown
    blocks until the event from ``create_aiohttp_closed_event`` is set, i.e.
    until every pooled SSL transport it hooked has reported
    ``connection_lost``.
    """
    async with ClientSession() as session:
        yield session

        # The hooks must be installed while the pooled connections still
        # exist, i.e. before close() is called.
        transports_gone = create_aiohttp_closed_event(session)
        await session.close()
        await transports_gone.wait()


@pytest.fixture(autouse=True)
def mocked_responses() -> Generator[aioresponses, None, None]:
    """Activate an ``aioresponses`` mock around every test.

    The fixture is ``autouse``, so the mock is active even in tests that do
    not request it; tests that do request it receive the mock object and can
    register canned aiohttp responses on it.
    """
    with aioresponses() as responses_mock:
        yield responses_mock


@pytest.fixture
def cli_runner(monkeypatch: MonkeyPatch, tmp_path: pathlib.Path) -> CliRunner:
    """Fixture for invoking command-line interfaces.

    Besides creating the ``CliRunner``, the fixture monkeypatches two helpers
    for the duration of the test:

    * ``os.path.expanduser`` expands a leading ``~`` to the test's temporary
      directory instead of the real home directory.
    * ``tzlocal.get_localzone`` returns a fixed ``+01:00`` zone without DST.

    Args:
        monkeypatch: pytest fixture used to install the patches.
        tmp_path: per-test temporary directory used as the fake home
            directory.

    Returns:
        A fresh ``CliRunner``.
    """
    runner = CliRunner()

    fake_home = str(tmp_path)

    def expanduser(path: str) -> str:
        # Like the real os.path.expanduser, only a *leading* "~" is expanded;
        # a "~" elsewhere in the path is left untouched.
        if path.startswith("~"):
            return fake_home + path[1:]
        return path

    monkeypatch.setattr("os.path.expanduser", expanduser)

    class TZ1(tzinfo):
        """Fixed UTC+01:00 zone with no DST."""

        def utcoffset(self, dt: Optional[datetime]) -> timedelta:
            return timedelta(hours=1)

        def dst(self, dt: Optional[datetime]) -> timedelta:
            return timedelta(0)

        def tzname(self, dt: Optional[datetime]) -> str:
            return "+01:00"

        def __repr__(self) -> str:
            return f"{self.__class__.__name__}()"

    def get_test_zone() -> tzinfo:
        # Get a non UTC zone, avoiding DST on standard zones.
        return TZ1()

    monkeypatch.setattr("tzlocal.get_localzone", get_test_zone)

    return runner


def create_aiohttp_closed_event(
    session: "ClientSession",
) -> asyncio.Event:
    """Work around aiohttp issue that doesn't properly close transports on exit.

    See https://github.com/aio-libs/aiohttp/issues/1925#issuecomment-639080209

    The ``connection_lost`` and ``eof_received`` methods of the SSL protocol
    of every connection currently pooled in ``session.connector`` are wrapped.
    The returned event is set once each wrapped protocol has reported
    ``connection_lost``. Intended to be called before ``session.close()``,
    because only connections pooled at call time are hooked.

    Args:
        session (ClientSession): session for which to generate the event.

    Returns:
        An event that will be set once all transports have been properly closed.
        It is returned already set when there is nothing to wait for: the
        session has no connector, or no pooled connection has an SSL protocol.
    """
    transports = 0
    all_is_lost = asyncio.Event()

    connector = session.connector
    if connector is None:
        # No connector (e.g. session already closed): nothing left to track.
        all_is_lost.set()
        return all_is_lost

    def connection_lost(exc, orig_lost):  # type: ignore[no-untyped-def]
        nonlocal transports

        try:
            orig_lost(exc)
        finally:
            # Count the transport as gone even if the original hook raised;
            # otherwise the event could never be set.
            transports -= 1
            if transports == 0:
                all_is_lost.set()

    def eof_received(orig_eof_received):  # type: ignore[no-untyped-def]
        try:
            # Propagate the result: asyncio treats a true return value from
            # eof_received() as "keep the transport open".
            return orig_eof_received()
        except AttributeError:
            # It may happen that eof_received() is called after
            # _app_protocol and _transport are set to None.
            return None

    for conn in connector._conns.values():
        for handler, _ in conn:
            # Only transports that expose an SSL protocol are tracked.
            proto = getattr(handler.transport, "_ssl_protocol", None)
            if proto is None:
                continue

            transports += 1
            orig_lost = proto.connection_lost
            orig_eof_received = proto.eof_received

            proto.connection_lost = functools.partial(
                connection_lost, orig_lost=orig_lost
            )
            proto.eof_received = functools.partial(
                eof_received, orig_eof_received=orig_eof_received
            )

    if transports == 0:
        all_is_lost.set()

    return all_is_lost