File: test_utils.py

package info (click to toggle)
python-aiostream 0.7.0-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 416 kB
  • sloc: python: 2,800; makefile: 5
file content (221 lines) | stat: -rw-r--r-- 6,104 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
"""Utilities for testing stream operators."""

from __future__ import annotations

import asyncio
from collections import deque
from contextlib import contextmanager
from unittest.mock import Mock
from typing import (
    TYPE_CHECKING,
    Awaitable,
    Callable,
    List,
    Protocol,
    TypeVar,
    AsyncIterable,
    AsyncIterator,
    ContextManager,
    Iterator,
    cast,
    Any,
)

import pytest

from .core import StreamEmpty, streamcontext, pipable_operator

if TYPE_CHECKING:
    from _pytest.fixtures import SubRequest
    from aiostream.core import Stream

__all__ = [
    "add_resource",
    "assert_run",
    "event_loop_policy",
    "assert_cleanup",
]


# Generic item type of the streams passed through ``add_resource``.
T = TypeVar("T")


@pipable_operator
async def add_resource(
    source: AsyncIterable[T], cleanup_time: float
) -> AsyncIterator[T]:
    """Forward *source* unchanged while pretending to hold an open resource.

    On start, the ``open_resources`` and ``resources`` counters of the
    running :class:`TimeTrackingTestLoop` are incremented. However the
    stream stops (exhausted, failed or closed), the operator first sleeps
    for *cleanup_time* seconds and then decrements ``open_resources``, so a
    test can verify that cleanup really ran.
    """
    running_loop = asyncio.get_running_loop()
    assert isinstance(running_loop, TimeTrackingTestLoop)
    try:
        running_loop.open_resources += 1
        running_loop.resources += 1
        async with streamcontext(source) as stream_iter:
            async for element in stream_iter:
                yield element
    finally:
        try:
            # Simulated cleanup delay before releasing the resource.
            await asyncio.sleep(cleanup_time)
        finally:
            # Release the counter even if the cleanup sleep is interrupted.
            running_loop.open_resources -= 1


def compare_exceptions(
    exc1: Exception,
    exc2: Exception,
) -> bool:
    """Tell whether two exceptions should be considered the same.

    They match when they compare equal, or when they have the same class
    and identical ``args``.
    """
    if exc1 == exc2:
        return True
    same_class = exc1.__class__ == exc2.__class__
    return same_class and exc1.args == exc2.args


async def assert_aiter(
    source: Stream[object],
    values: List[object],
    exception: Exception | None = None,
) -> None:
    """Iterate *source* through ``streamcontext`` and check what it produced.

    Every item yielded before the stream stops is collected and must equal
    *values*. If *exception* is given, the stream must raise an exception
    of that type that ``compare_exceptions`` accepts; otherwise it must
    complete without raising.
    """
    collected: list[object] = []
    # Catching an empty tuple catches nothing: with no expected exception,
    # any error propagates straight to the test.
    catchable = () if not exception else (type(exception),)
    try:
        async with streamcontext(source) as streamer:
            async for element in streamer:
                # Append one by one so items seen before a failure are kept.
                collected.append(element)
    except catchable as raised:
        assert exception is not None
        assert compare_exceptions(raised, exception)
    else:
        assert exception is None
    assert collected == values


async def assert_await(
    source: Stream[object],
    values: List[object],
    exception: Exception | None = None,
) -> None:
    """Check the result of a stream by awaiting it.

    A successful await must return ``values[-1]`` and no exception may be
    expected. ``StreamEmpty`` is accepted only when *values* is empty and no
    exception is expected. If *exception* is given, the await must raise an
    exception of that type that ``compare_exceptions`` accepts.
    """
    exception_type = (type(exception),) if exception is not None else ()
    try:
        result = await source
    except StreamEmpty:
        assert values == []
        assert exception is None
    except exception_type as exc:
        assert exception is not None
        assert compare_exceptions(exc, exception)
    else:
        # Guard before indexing: a value returned where none was expected
        # must fail as an assertion, not as an unrelated IndexError.
        assert values, (
            f"awaiting the stream returned {result!r}, but no values were expected"
        )
        assert result == values[-1]
        assert exception is None


class AssertRunProtocol(Protocol):
    """Call signature shared by the stream runners behind ``assert_run``.

    A runner consumes *source*, checks that it produced *values* and, when
    *exception* is given, that it raised a matching exception. The
    ``exception`` parameter defaults to ``None``, mirroring the concrete
    runners ``assert_aiter`` and ``assert_await``, so callers may omit it.
    """

    def __call__(
        self,
        source: Stream[object],
        values: List[object],
        exception: Exception | None = None,
    ) -> Awaitable[None]: ...


@pytest.fixture(params=[assert_aiter, assert_await], ids=["aiter", "await"])  # type: ignore[misc]
def assert_run(request: SubRequest) -> AssertRunProtocol:
    """Provide each stream runner in turn (``aiter`` then ``await``).

    Tests using this fixture run once per parametrization. The value is the
    runner selected by ``request.param``.
    """
    runner = request.param
    return cast(AssertRunProtocol, runner)


@pytest.fixture  # type: ignore[misc]
def event_loop_policy() -> TimeTrackingTestLoopPolicy:
    """Provide an event loop policy whose loops run on a simulated clock.

    Loops created by this policy are :class:`TimeTrackingTestLoop` instances.
    Instead of really sleeping, they jump their clock forward and record
    each jump in ``loop.steps``. They also count simulated resources, so a
    test can check through ``loop.assert_cleanup()`` that every resource
    was released.
    """
    policy = TimeTrackingTestLoopPolicy()
    return policy


@pytest.fixture  # type: ignore[misc]
def assert_cleanup() -> Callable[[], ContextManager[TimeTrackingTestLoop]]:
    """Provide a factory for the running loop's resource-cleanup check.

    Call the returned function from code running on a
    :class:`TimeTrackingTestLoop`. It returns that loop's ``assert_cleanup()``
    context manager, which resets the loop's counters on entry. On normal
    exit, it asserts that no simulated resource is still open.
    """

    def _cleanup_checker() -> ContextManager[TimeTrackingTestLoop]:
        current = asyncio.get_running_loop()
        assert isinstance(current, TimeTrackingTestLoop)
        return current.assert_cleanup()

    return _cleanup_checker


class BaseEventLoopWithInternals(asyncio.BaseEventLoop):
    """Annotation-only shim declaring private ``asyncio.BaseEventLoop`` attributes.

    Nothing is assigned here, so runtime behavior is exactly that of
    ``asyncio.BaseEventLoop``. The annotations only declare the internals
    that ``TimeTrackingTestLoop`` reads or overrides. Presumably this exists
    so the type checker accepts that access (TODO confirm against the
    typeshed stubs).
    """

    # asyncio's internal deque of handles ready to run; read by
    # TimeTrackingTestLoop.time_to_go.
    _ready: deque[asyncio.Handle]
    # One iteration of the loop; extended by TimeTrackingTestLoop._run_once.
    _run_once: Callable[[], None]


class TimeTrackingTestLoop(BaseEventLoopWithInternals):
    """Asyncio event loop driven by a simulated clock.

    The clock (``time()``) only moves when the loop jumps to the next
    scheduled timer. It does so when nothing is ready to run, or when it has
    gone more than ``stuck_threshold`` iterations without a jump. Each jump
    is appended to ``steps``, so tests can assert on the exact sequence of
    simulated sleeps.

    The ``open_resources`` and ``resources`` counters are updated by
    ``add_resource``. ``assert_cleanup`` resets them and checks that no
    simulated resource is left open.
    """

    # Iterations without a clock jump after which the loop counts as stuck
    # and jumps to the next timer even though handles are ready.
    stuck_threshold: int = 100

    def __init__(self) -> None:
        super().__init__()
        # Simulated clock and the deadlines passed to call_at().
        self._time: float = 0.0
        self._timers: list[float] = []
        # Stand-in selector: calls on a Mock return immediately, so the
        # base loop never blocks on real I/O.
        self._selector = Mock()

        # State observed by tests.
        self.steps: list[float] = []
        self.open_resources: int = 0
        self.resources: int = 0
        self.busy_count: int = 0

    # Loop internals

    def _run_once(self) -> None:  # type: ignore
        super()._run_once()
        self.busy_count += 1
        now = self.time()
        # Drop deadlines that are no longer in the future.
        self._timers = sorted(deadline for deadline in self._timers if deadline > now)
        if not self.time_to_go:
            return
        # Jump the clock straight to the earliest pending deadline.
        delta = self._timers.pop(0) - now
        self.steps.append(delta)
        self.advance_time(delta)
        self.busy_count = 0

    def _process_events(self, event_list: object) -> None:
        """Ignore selector results; the selector is a Mock with no real events."""

    def _write_to_self(self) -> None:
        """No-op: with a mocked selector there is no real wake-up channel."""

    # Time management

    def time(self) -> float:
        """Return the current simulated time."""
        return self._time

    def advance_time(self, advance: float) -> None:
        """Move the simulated clock forward by *advance* (no-op when falsy)."""
        if not advance:
            return
        self._time += advance

    def call_at(self, when: float, callback: Callable[..., None], *args: Any, **kwargs: Any) -> asyncio.TimerHandle:  # type: ignore
        # Remember the deadline so _run_once knows where it may jump next.
        self._timers.append(when)
        return super().call_at(when, callback, *args, **kwargs)

    @property
    def stuck(self) -> bool:
        """Whether the loop spun more than ``stuck_threshold`` times without a jump."""
        return self.busy_count > self.stuck_threshold

    @property
    def time_to_go(self) -> bool:
        """Whether to jump the clock: a timer is pending and the loop is idle or stuck."""
        if not self._timers:
            return False
        return self.stuck or not self._ready

    # Resource management

    def clear(self) -> None:
        """Reset the recorded steps and every counter."""
        self.steps = []
        self.busy_count = 0
        self.resources = 0
        self.open_resources = 0

    @contextmanager
    def assert_cleanup(self) -> Iterator[TimeTrackingTestLoop]:
        """Reset state, run the body, then assert that no resource is open.

        If the body raises, the check and the final reset are skipped.
        """
        self.clear()
        yield self
        assert self.open_resources == 0
        self.clear()


class TimeTrackingTestLoopPolicy(asyncio.DefaultEventLoopPolicy):
    """Event loop policy whose new loops are ``TimeTrackingTestLoop`` instances."""

    # Private hook of asyncio's default policy; presumably new_event_loop()
    # calls this factory. NOTE(review): relies on asyncio internals, and the
    # policy API is reportedly deprecated from Python 3.14 — confirm.
    _loop_factory = TimeTrackingTestLoop