File: base.py

package info (click to toggle)
python-aiormq 6.8.1-1
  • links: PTS, VCS
  • area: main
  • in suites: sid, trixie
  • size: 392 kB
  • sloc: python: 3,214; makefile: 27
file content (166 lines) | stat: -rw-r--r-- 4,800 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
import abc
import asyncio
from contextlib import suppress
from functools import wraps
from typing import Any, Callable, Coroutine, Optional, Set, TypeVar, Union
from weakref import WeakSet

from .abc import (
    AbstractBase, AbstractFutureStore, CoroutineType, ExceptionType, TaskType,
    TaskWrapper, TimeoutType,
)
from .tools import Countdown, shield


# Generic type variable; used to parameterise TaskFunctionType in this module.
T = TypeVar("T")


class FutureStore(AbstractFutureStore):
    """Track pending futures/tasks so they can all be rejected at once.

    Everything registered through :meth:`add` (directly, or via
    :meth:`create_future` / :meth:`create_task`) is kept in ``futures``
    until it completes; a done-callback then drops it from the set.
    A store obtained from :meth:`get_child` also forwards every future it
    registers to its parent store, so rejecting the parent reaches the
    child's pending futures too.
    """

    __slots__ = "futures", "loop", "parent"

    futures: Set[Union[asyncio.Future, TaskType]]
    # NOTE(review): declared but never assigned by this class (and not
    # listed in __slots__); kept only so the class annotations are unchanged.
    weak_futures: WeakSet
    loop: asyncio.AbstractEventLoop

    def __init__(self, loop: asyncio.AbstractEventLoop):
        self.futures = set()
        self.loop = loop
        # Assigned by get_child(); when set, add() mirrors futures into it.
        self.parent: Optional[FutureStore] = None

    def __on_task_done(
        self, future: Union[asyncio.Future, TaskWrapper],
    ) -> Callable[..., Any]:
        """Build a done-callback that forgets *future* once it completes."""
        def remover(*_: Any) -> None:
            # discard() tolerates the future already being gone, e.g. when
            # reject_all() popped it before this callback got to run.
            self.futures.discard(future)

        return remover

    def add(self, future: Union[asyncio.Future, TaskWrapper]) -> None:
        """Register *future* in this store and, if any, in the parent store.

        The future is removed from ``futures`` automatically once done.
        """
        self.futures.add(future)
        future.add_done_callback(self.__on_task_done(future))

        if self.parent is not None:
            self.parent.add(future)

    @shield
    async def reject_all(self, exception: Optional[ExceptionType]) -> None:
        """Fail every not-yet-done future and task held by this store.

        The store is drained completely. Plain ``asyncio.Future`` objects
        get the exception set on them; ``TaskWrapper`` objects have it
        thrown into them and are then awaited together with
        ``asyncio.gather(..., return_exceptions=True)`` so their errors are
        collected instead of propagated. A falsy *exception* falls back to
        ``Exception``. Already-done entries are skipped.
        """
        error = exception or Exception
        tasks = []

        while self.futures:
            future: Union[TaskType, asyncio.Future] = self.futures.pop()

            if future.done():
                continue

            if isinstance(future, TaskWrapper):
                future.throw(error)
                tasks.append(future)
            elif isinstance(future, asyncio.Future):
                future.set_exception(error)

        if tasks:
            await asyncio.gather(*tasks, return_exceptions=True)

    def create_task(self, coro: CoroutineType) -> TaskType:
        """Schedule *coro* on ``self.loop`` as a tracked ``TaskWrapper``."""
        task: TaskWrapper = TaskWrapper(self.loop.create_task(coro))
        self.add(task)
        return task

    def create_future(self, weak: bool = False) -> asyncio.Future:
        """Create a future on ``self.loop`` and track it in this store.

        NOTE(review): *weak* is accepted but ignored by this implementation.
        """
        future = self.loop.create_future()
        self.add(future)
        return future

    def get_child(self) -> "FutureStore":
        """Return a new store on the same loop whose futures also land here."""
        store = FutureStore(self.loop)
        store.parent = self
        return store


class Base(AbstractBase):
    """Shared lifecycle for objects that own futures/tasks and can be closed.

    Each instance owns a future store (a child of the parent's store when a
    *parent* is given) and a ``closing`` future created in that store. The
    object counts as closed once ``closing`` is done. :meth:`close` runs the
    subclass hook :meth:`_on_close` and then rejects every pending future
    in the store, ``closing`` included.
    """

    __slots__ = "loop", "__future_store", "closing"

    def __init__(
        self, *, loop: asyncio.AbstractEventLoop,
        parent: Optional[AbstractBase] = None,
    ):
        self.loop: asyncio.AbstractEventLoop = loop

        if parent:
            # Tie our futures to the parent's store via a child store.
            self.__future_store = parent._future_store_child()
        else:
            self.__future_store = FutureStore(loop=self.loop)

        self.closing = self._create_closing_future()

    def _create_closing_future(self) -> asyncio.Future:
        """Create the tracked future whose completion marks us as closed."""
        future = self.__future_store.create_future()

        def retrieve_exception(fut: asyncio.Future) -> None:
            # Retrieve the exception so asyncio does not warn that it was
            # never retrieved. Future.exception() itself raises
            # CancelledError on a cancelled future, which would make this
            # callback fail inside the event loop, so skip that case.
            if not fut.cancelled():
                fut.exception()

        future.add_done_callback(retrieve_exception)
        return future

    def _cancel_tasks(
        self, exc: Optional[ExceptionType] = None,
    ) -> Coroutine[Any, Any, None]:
        """Return an awaitable that rejects all pending futures with *exc*."""
        return self.__future_store.reject_all(exc)

    def _future_store_child(self) -> AbstractFutureStore:
        """Return a child of our future store (used by child objects)."""
        return self.__future_store.get_child()

    def create_task(self, coro: CoroutineType) -> TaskType:
        """Schedule *coro* as a task tracked by this object's store."""
        return self.__future_store.create_task(coro)

    def create_future(self) -> asyncio.Future:
        """Create a future tracked by this object's store."""
        return self.__future_store.create_future()

    @abc.abstractmethod
    async def _on_close(
        self, exc: Optional[ExceptionType] = None,
    ) -> None:  # pragma: no cover
        """Subclass hook run first during :meth:`close`."""
        return

    async def __closer(self, exc: Optional[ExceptionType]) -> None:
        if self.is_closed:  # pragma: no cover
            return

        # Both steps are best-effort: an error in the subclass hook must not
        # prevent the pending futures from being rejected afterwards.
        with suppress(Exception):
            await self._on_close(exc)

        with suppress(Exception):
            await self._cancel_tasks(exc)

    async def close(
        self, exc: Optional[ExceptionType] = asyncio.CancelledError,
        timeout: TimeoutType = None,
    ) -> None:
        """Close the object, rejecting pending futures with *exc*.

        Does nothing if already closed. The shutdown coroutine is run
        through ``Countdown(timeout)``; presumably this bounds it by
        *timeout* (semantics live in ``.tools`` -- confirm there).
        """
        if self.is_closed:
            return None

        countdown = Countdown(timeout)
        await countdown(self.__closer(exc))

    def __repr__(self) -> str:
        cls_name = self.__class__.__name__
        return '<{0}: "{1}" at 0x{2:02x}>'.format(
            cls_name, str(self), id(self),
        )

    @abc.abstractmethod
    def __str__(self) -> str:  # pragma: no cover
        raise NotImplementedError

    @property
    def is_closed(self) -> bool:
        """True once the ``closing`` future is done."""
        return self.closing.done()


# Type of the callables accepted and returned by the ``task`` decorator.
TaskFunctionType = Callable[..., T]


def task(func: TaskFunctionType) -> TaskFunctionType:
    """Make each call of the decorated method run as a task of its owner.

    The returned coroutine function calls ``func(self, *args, **kwargs)``,
    passes the result to ``self.create_task`` and awaits that task,
    returning its result. *func*'s metadata is copied onto the wrapper.
    """

    async def wrapper(self: Base, *args: Any, **kwargs: Any) -> Any:
        coro = func(self, *args, **kwargs)
        scheduled = self.create_task(coro)
        return await scheduled

    return wraps(func)(wrapper)