1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159
|
"""Provide a context to easily manage several streamers running
concurrently.
"""
from __future__ import annotations
import asyncio
from .aiter_utils import AsyncExitStack
from .aiter_utils import anext
from .core import streamcontext
from typing import (
TYPE_CHECKING,
Awaitable,
List,
Set,
Tuple,
Generic,
TypeVar,
Any,
Type,
AsyncIterable,
)
from types import TracebackType
if TYPE_CHECKING:
from asyncio import Task
from aiostream.core import Streamer
# Generic item type used by the task and streamer annotations in this module.
T = TypeVar("T")
class TaskGroup:
    """Track asyncio tasks and cancel whatever is left when the group exits.

    Tasks are spawned with :meth:`create_task` and stay "pending" until they
    are reported as done by :meth:`wait_any` / :meth:`wait_all`, or are
    disposed of with :meth:`cancel_task`. Leaving the ``async with`` block
    cancels and awaits every task that is still pending.
    """

    def __init__(self) -> None:
        # Tasks created by this group that have not been collected yet.
        self._pending: set[Task[Any]] = set()

    async def __aenter__(self) -> TaskGroup:
        return self

    async def __aexit__(
        self,
        typ: Type[BaseException] | None,
        value: BaseException | None,
        traceback: TracebackType | None,
    ) -> None:
        """Cancel and await every task still pending; never suppress exceptions.

        The set is drained with a ``while`` loop, so tasks that get added to
        the group while earlier ones are being cancelled are handled too.
        """
        while self._pending:
            task = self._pending.pop()
            await self.cancel_task(task)

    def create_task(self, coro: Awaitable[T]) -> Task[T]:
        """Schedule *coro* with ``asyncio.ensure_future`` and track the result."""
        task = asyncio.ensure_future(coro)
        self._pending.add(task)
        return task

    async def wait_any(self, tasks: List[Task[T]]) -> Set[Task[T]]:
        """Wait until at least one of *tasks* completes.

        Returns the set of completed tasks; they stop being tracked as
        pending. *tasks* must not be empty: ``asyncio.wait`` raises
        ``ValueError`` for an empty collection.
        """
        done, _ = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
        self._pending -= done
        return done

    async def wait_all(self, tasks: List[Task[T]]) -> Set[Task[T]]:
        """Wait until all of *tasks* complete and return them as a set.

        An empty *tasks* list returns an empty set immediately (it is
        special-cased because ``asyncio.wait`` rejects empty input).
        """
        if not tasks:
            return set()
        done, _ = await asyncio.wait(tasks)
        self._pending -= done
        return done

    async def cancel_task(self, task: Task[Any]) -> None:
        """Dispose of *task*: cancel it if needed, then stop tracking it.

        Whatever the task produced (result, exception, or cancellation) is
        discarded. This is best-effort: ordinary exceptions raised while
        cancelling are silenced, and the task is always removed from the
        pending set.
        """
        try:
            if task.cancelled():
                # Already cancelled: nothing left to do.
                pass
            elif task.done():
                # Already finished: retrieve and discard the pending exception
                # (if any). We don't know in which context the exception was
                # meant to be processed; for instance, a `StopAsyncIteration`
                # might be raised to notify that the end of a streamer has
                # been reached.
                task.exception()
            else:
                # Still running: request cancellation and wait for it to land.
                task.cancel()
                try:
                    await task
                except asyncio.CancelledError:
                    pass
        # Silence any exception raised while cancelling the task. This might
        # happen if the `CancelledError` is silenced and the corresponding
        # async generator returns, causing the `anext` call to raise a
        # `StopAsyncIteration`.
        except Exception:
            pass
        finally:
            self._pending.discard(task)
class StreamerManager(Generic[T]):
    """Run several streamers concurrently and clean them up together.

    Each streamer is entered when it is added and has at most one pending
    ``anext(streamer)`` task, stored in :attr:`tasks`. Tasks are spawned in a
    private :class:`TaskGroup`, which is registered on the exit stack when
    the manager is entered.
    """

    def __init__(self) -> None:
        # Pending ``anext(streamer)`` task for each streamer that has one.
        self.tasks: dict[Streamer[T], Task[T]] = {}
        # Streamers entered by this manager and not cleaned yet, in creation order.
        self.streamers: list[Streamer[T]] = []
        self.group: TaskGroup = TaskGroup()
        self.stack = AsyncExitStack()

    async def __aenter__(self) -> StreamerManager[T]:
        await self.stack.__aenter__()
        await self.stack.enter_async_context(self.group)
        return self

    async def __aexit__(
        self,
        typ: Type[BaseException] | None,
        value: BaseException | None,
        traceback: TracebackType | None,
    ) -> bool:
        """Cancel pending tasks and exit every remaining streamer.

        For each streamer, the cancellation of its pending task is pushed
        onto the exit stack before the streamer's own ``__aexit__``. Assuming
        ``AsyncExitStack`` unwinds in LIFO order (as ``contextlib``'s does):
        streamers are exited in reverse creation order, each streamer's
        ``__aexit__`` runs before its pending task is cancelled, and the task
        group registered in ``__aenter__`` is exited last.
        Returns the exit stack's result (True if the exception was suppressed).
        """
        # NOTE(review): a streamer is exited while its ``anext`` task may still
        # be running; confirm this ordering is intended.
        for streamer in self.streamers:
            task = self.tasks.pop(streamer, None)
            if task is not None:
                self.stack.push_async_callback(self.group.cancel_task, task)
            self.stack.push_async_exit(streamer)
        self.tasks.clear()
        self.streamers.clear()
        return await self.stack.__aexit__(typ, value, traceback)

    async def enter_and_create_task(self, aiter: AsyncIterable[T]) -> Streamer[T]:
        """Enter a streamer over *aiter*, register it, and start its first ``anext``."""
        streamer = streamcontext(aiter)
        await streamer.__aenter__()
        self.streamers.append(streamer)
        self.create_task(streamer)
        return streamer

    def create_task(self, streamer: Streamer[T]) -> None:
        """Start an ``anext(streamer)`` task for a registered, idle streamer."""
        # Internal invariants: the streamer is managed here and has no task yet.
        assert streamer in self.streamers
        assert streamer not in self.tasks
        self.tasks[streamer] = self.group.create_task(anext(streamer))

    async def wait_single_event(
        self, filters: list[Streamer[T]]
    ) -> Tuple[Streamer[T], Task[T]]:
        """Wait until the task of one of the *filters* streamers completes.

        Returns ``(streamer, task)`` for the first streamer, in *filters*
        order, whose task is done; that task is removed from :attr:`tasks`.
        Raises ``KeyError`` if a streamer in *filters* has no pending task.
        *filters* must not be empty.
        """
        tasks = [self.tasks[streamer] for streamer in filters]
        done = await self.group.wait_any(tasks)
        for streamer in filters:
            if self.tasks.get(streamer) in done:
                return streamer, self.tasks.pop(streamer)
        # Should be unreachable: ``done`` is a non-empty subset of ``tasks``.
        # Raise explicitly rather than ``assert`` so the check survives ``-O``
        # instead of silently returning ``None``.
        raise AssertionError("no completed task matches the given streamers")

    async def clean_streamer(self, streamer: Streamer[T]) -> None:
        """Cancel *streamer*'s pending task (if any), close it, and unregister it.

        If ``aclose()`` raises, the streamer is left in :attr:`streamers`.
        """
        task = self.tasks.pop(streamer, None)
        if task is not None:
            await self.group.cancel_task(task)
        await streamer.aclose()
        self.streamers.remove(streamer)

    async def clean_streamers(self, streamers: list[Streamer[T]]) -> None:
        """Clean up *streamers* concurrently, then re-raise the first failure.

        Each streamer is cleaned by :meth:`clean_streamer` in its own task.
        Once all of them have finished, every task's exception is retrieved,
        so asyncio does not report any as "never retrieved". The first
        failure, in the order of *streamers*, is then re-raised.
        """
        tasks = [
            self.group.create_task(self.clean_streamer(streamer))
            for streamer in streamers
        ]
        done = await self.group.wait_all(tasks)
        # Iterate the list (not the ``done`` set) for a deterministic order.
        # ``exception()`` marks each task's exception as retrieved.
        failed = [
            task for task in tasks if task in done and task.exception() is not None
        ]
        if failed:
            failed[0].result()
|