File: tools.py

package info (click to toggle)
python-aiormq 6.8.1-1
  • links: PTS, VCS
  • area: main
  • in suites: sid, trixie
  • size: 392 kB
  • sloc: python: 3,214; makefile: 27
file content (118 lines) | stat: -rw-r--r-- 3,284 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
import asyncio
import inspect
import platform
import time
from functools import wraps
from types import TracebackType
from typing import (
    Any, AsyncContextManager, Awaitable, Callable, Coroutine, Optional, Type,
    TypeVar, Union,
)

from yarl import URL

from aiormq.abc import TimeoutType


# Generic result type used by the awaitable helpers and Countdown below.
T = TypeVar("T")


def censor_url(url: URL) -> URL:
    """Return *url* with its password replaced by ``******``.

    A URL that carries no password is returned unchanged (the very same
    object), so the common case costs nothing.
    """
    if url.password is None:
        return url
    return url.with_password("******")


def shield(func: Callable[..., Awaitable[T]]) -> Callable[..., Awaitable[T]]:
    """Decorate *func* so that each call's awaitable is passed to ``asyncio.shield``.

    The returned callable is a plain (non-``async``) function: calling it
    invokes *func* right away and hands back the shielded awaitable, which
    protects the inner operation from cancellation of the awaiting task.
    """

    @wraps(func)
    def shielded(*args: Any, **kwargs: Any) -> Awaitable[T]:
        pending = func(*args, **kwargs)
        return asyncio.shield(pending)

    return shielded


def awaitable(
    func: Callable[..., Union[T, Awaitable[T]]],
) -> Callable[..., Coroutine[Any, Any, T]]:
    """Adapt *func* so that calling it always produces a coroutine.

    Coroutine functions are returned unchanged. Any other callable is
    wrapped in an ``async def`` that calls it and awaits the result only if
    that result is awaitable; non-awaitable results are returned as-is.

    :param func: a synchronous or asynchronous callable.
    :return: a callable whose calls return a coroutine resolving to the
        (awaited, if necessary) result of *func*.
    """
    # Coroutine functions already satisfy the contract; return them as-is
    # (the original code notes this avoids a Python 3.8+ warning).
    # inspect.iscoroutinefunction is used because
    # asyncio.iscoroutinefunction is deprecated since Python 3.14.
    if inspect.iscoroutinefunction(func):
        return func     # type: ignore

    @wraps(func)
    async def wrap(*args: Any, **kwargs: Any) -> T:
        result = func(*args, **kwargs)

        # inspect.isawaitable accepts exactly what ``await`` accepts: native
        # coroutines, @types.coroutine generators, and objects whose type
        # defines __await__ (futures, tasks, ...). asyncio.iscoroutine(),
        # used previously as a fallback, also matches *plain* generators on
        # Python < 3.12, and awaiting those raises TypeError.
        if inspect.isawaitable(result):
            return await result     # type: ignore

        return result               # type: ignore

    return wrap


class Countdown:
    """Deadline shared by a sequence of awaited operations.

    The deadline is fixed at construction time (``timeout`` seconds from
    now); every awaitable passed to the instance is bounded by whatever time
    remains. ``timeout=None`` means there is no deadline at all.
    """

    __slots__ = "loop", "deadline"

    if platform.system() != "Windows":
        @staticmethod
        def _now() -> float:
            return time.monotonic()
    else:
        @staticmethod
        def _now() -> float:
            # The monotonic timer's resolution on Windows is considered too
            # coarse here, so wall-clock time.time() is used instead.
            return time.time()

    def __init__(self, timeout: TimeoutType = None):
        # NOTE: ``loop`` is not read by any method of this class; it is kept
        # as a public attribute.
        self.loop: asyncio.AbstractEventLoop = asyncio.get_event_loop()
        self.deadline: TimeoutType = (
            None if timeout is None else self._now() + timeout
        )

    def get_timeout(self) -> TimeoutType:
        """Return the seconds left before the deadline, or None if unbounded.

        :raises asyncio.TimeoutError: if the deadline has already been reached.
        """
        deadline = self.deadline
        if deadline is None:
            return None

        now = self._now()
        if now >= deadline:
            raise asyncio.TimeoutError

        return deadline - now

    async def __call__(self, coro: Awaitable[T]) -> T:
        """Await *coro*, bounded by the time remaining on this countdown.

        :raises asyncio.TimeoutError: if the deadline had already passed
            (*coro* is cancelled first) or elapses while waiting.
        """
        try:
            remaining = self.get_timeout()
        except asyncio.TimeoutError:
            # Wrap, cancel and drain the awaitable instead of just dropping
            # it, so it is disposed of before the timeout propagates.
            doomed = asyncio.ensure_future(coro)
            doomed.cancel()
            await asyncio.gather(doomed, return_exceptions=True)
            raise

        if self.deadline is not None or remaining:
            return await asyncio.wait_for(coro, timeout=remaining)
        return await coro

    def enter_context(
        self, ctx: AsyncContextManager[T],
    ) -> AsyncContextManager[T]:
        """Wrap *ctx* so that entering and exiting it use this countdown."""
        wrapped = CountdownContext(self, ctx)
        return wrapped


class CountdownContext(AsyncContextManager):
    """Async context manager that runs another one under a Countdown.

    Returned by ``Countdown.enter_context``. Entering the wrapped context is
    bounded by the countdown's remaining time. Exiting it is bounded as well
    while time remains; once the deadline has passed the wrapped context is
    still exited (without a time bound) before ``asyncio.TimeoutError`` is
    raised, so resources acquired on entry are not leaked.
    """

    def __init__(self, countdown: Countdown, ctx: AsyncContextManager):
        self.countdown: Countdown = countdown
        self.ctx: AsyncContextManager = ctx

    async def __aenter__(self) -> T:
        # If the deadline has already passed, the countdown cancels the
        # inner __aenter__ before it runs and raises asyncio.TimeoutError.
        return await self.countdown(self.ctx.__aenter__())

    async def __aexit__(
        self, exc_type: Optional[Type[BaseException]],
        exc_val: Optional[BaseException], exc_tb: Optional[TracebackType],
    ) -> Any:
        exit_coro = self.ctx.__aexit__(exc_type, exc_val, exc_tb)

        # Read the remaining budget once, so the expiry check and the
        # bounded wait below agree on the same value.
        try:
            timeout = self.countdown.get_timeout()
        except asyncio.TimeoutError:
            # The wrapped context was entered and must still be exited.
            # Routing this through the countdown would cancel ``exit_coro``
            # before it starts, e.g. leaving an asyncio.Lock locked forever.
            # Run the cleanup, then report the timeout as before.
            await exit_coro
            raise

        if timeout is None:
            return await exit_coro
        return await asyncio.wait_for(exit_coro, timeout=timeout)