1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159
|
from __future__ import annotations

import asyncio
import functools
import threading
from collections.abc import Callable, Coroutine
from typing import TYPE_CHECKING, Any, TypeVar
if TYPE_CHECKING:
    # Typing-only names. TYPE_CHECKING is False at runtime, so none of these
    # exist when the module is imported; that is safe because
    # ``from __future__ import annotations`` keeps this module's annotations
    # (which reference P, T and Coro) as unevaluated strings.
    from typing_extensions import ParamSpec

    P = ParamSpec("P")  # call signature of a wrapped coroutine function
    R = TypeVar("R")  # NOTE(review): not referenced in this module section; may be used elsewhere
    T = TypeVar("T")  # value produced by awaiting a coroutine
    Coro = Coroutine[Any, Any, T]  # shorthand for "a coroutine returning T"
class _LoopBoundMixin:
    """Backport of the private asyncio.mixins._LoopBoundMixin from 3.11

    An instance binds itself to the first running event loop that calls
    ``_get_loop`` and refuses to be used from any other loop afterwards.
    """

    # One lock shared by all instances (class attribute); it only guards the
    # first binding of ``_loop``.
    _global_lock = threading.Lock()
    # Class-level default; shadowed by an instance attribute once bound.
    _loop = None

    def _get_loop(self):
        """Return the running event loop, binding this object to it on first use.

        Returns
        -------
        The currently running event loop.

        Raises
        ------
        RuntimeError
            If no event loop is running, or if the running loop differs from
            the one this object was first bound to.
        """
        running_loop = asyncio.get_running_loop()
        if self._loop is None:
            # Re-check under the lock so that two threads racing for the
            # first binding cannot both assign a loop.
            with self._global_lock:
                if self._loop is None:
                    self._loop = running_loop
        if running_loop is self._loop:
            return running_loop
        raise RuntimeError(f"{self!r} is bound to a different event loop")
class AsyncTaskGroupClosedError(RuntimeError):
    """Raised when new work is scheduled on an ``AsyncTaskGroup`` that is closed."""
def _delayed(corofunc: Callable[P, Coro[T]], delay: float) -> Callable[P, Coro[T]]:
    """Decorator to delay the evaluation of a coroutine function by the given delay in seconds.

    Parameters
    ----------
    corofunc
        Coroutine function to wrap.
    delay
        Number of seconds to ``asyncio.sleep`` before ``corofunc`` is called.

    Returns
    -------
    A coroutine function accepting the same arguments as ``corofunc``. When
    awaited, it sleeps for ``delay`` seconds, then calls and awaits
    ``corofunc`` and returns its result. ``corofunc`` itself is not called
    until the sleep has finished.
    """

    # Copy __name__/__qualname__/__doc__ (and set __wrapped__) so the
    # coroutines and tasks created from the wrapper are identifiable as
    # ``corofunc`` rather than as an anonymous ``_delayed.<locals>.wrapper``.
    # Attributes the callable lacks (e.g. on functools.partial) are skipped.
    @functools.wraps(corofunc)
    async def wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
        await asyncio.sleep(delay)
        return await corofunc(*args, **kwargs)

    return wrapper
class AsyncTaskGroup(_LoopBoundMixin):
    """Collection tracking all currently running asynchronous tasks within a group

    Tasks are created on the event loop the group is first used from, and
    each task removes itself from the group once it is done.
    """

    #: If True, the group is closed and does not allow adding new tasks.
    closed: bool

    def __init__(self) -> None:
        self.closed = False
        # References to every task that has not been removed yet; entries are
        # dropped by a done-callback registered in ``call_soon``.
        self._ongoing_tasks: set[asyncio.Task[None]] = set()

    def call_soon(
        self, afunc: Callable[P, Coro[None]], /, *args: P.args, **kwargs: P.kwargs
    ) -> None:
        """Schedule a coroutine function to be executed as an `asyncio.Task`.

        The coroutine function `afunc` is scheduled with `args` arguments and `kwargs` keyword arguments
        as an `asyncio.Task`.

        Parameters
        ----------
        afunc
            Coroutine function to schedule.
        *args
            Arguments to be passed to `afunc`.
        **kwargs
            Keyword arguments to be passed to `afunc`

        Returns
        -------
        None

        Raises
        ------
        AsyncTaskGroupClosedError
            If the task group is closed.
        RuntimeError
            If no event loop is running, or if the running loop is not the one
            this group is bound to.
        """
        if self.closed:  # Avoid creating a coroutine
            raise AsyncTaskGroupClosedError(
                "Cannot schedule a new coroutine function as the group is already closed."
            )
        # Resolve (and possibly bind) the loop *before* calling ``afunc`` so a
        # loop error does not leave behind a coroutine that is never awaited.
        loop = self._get_loop()
        task = loop.create_task(afunc(*args, **kwargs))
        # Track the task first, then let it untrack itself once it is done.
        self._ongoing_tasks.add(task)
        task.add_done_callback(self._ongoing_tasks.remove)
        return None

    def call_later(
        self,
        delay: float,
        afunc: Callable[P, Coro[None]],
        /,
        *args: P.args,
        **kwargs: P.kwargs,
    ) -> None:
        """Schedule a coroutine function to be executed after `delay` seconds as an `asyncio.Task`.

        The coroutine function `afunc` is scheduled with `args` arguments and `kwargs` keyword arguments
        as an `asyncio.Task` that is executed after `delay` seconds.

        Parameters
        ----------
        delay
            Delay in seconds.
        afunc
            Coroutine function to schedule.
        *args
            Arguments to be passed to `afunc`.
        **kwargs
            Keyword arguments to be passed to `afunc`

        Returns
        -------
        None

        Raises
        ------
        AsyncTaskGroupClosedError
            If the task group is closed.
        RuntimeError
            If no event loop is running, or if the running loop is not the one
            this group is bound to.
        """
        # The task is created immediately; the delay happens inside it, so the
        # pending task is tracked (and cancellable by ``stop``) while it waits.
        self.call_soon(_delayed(afunc, delay), *args, **kwargs)

    def close(self) -> None:
        """Closes the task group so that no new tasks can be scheduled.

        Existing tasks continue to run.
        """
        self.closed = True

    async def stop(self) -> None:
        """Close the group and stop all currently running tasks.

        Closes the task group and cancels all tasks. All tasks are cancelled
        an additional time for each time this task is cancelled.
        """
        self.close()

        # ``stop`` may itself be running as one of the group's tasks; it must
        # never cancel or wait on itself.
        current_task = asyncio.current_task(self._get_loop())
        err = None
        # Repeat until every other task is gone. If *this* coroutine is
        # cancelled while waiting, remember the cancellation, cancel the
        # remaining tasks again and keep waiting; re-raise it only at the end.
        while tasks_to_stop := (self._ongoing_tasks - {current_task}):
            for task in tasks_to_stop:
                task.cancel()
            try:
                await asyncio.wait(tasks_to_stop)
            except asyncio.CancelledError as e:
                err = e

        if err is not None:
            raise err

    def __len__(self) -> int:
        """Return the number of tasks currently tracked by the group."""
        return len(self._ongoing_tasks)
|