"""Advanced operators (to deal with streams of higher order) ."""
from __future__ import annotations
from typing import AsyncIterator, AsyncIterable, TypeVar, Union, cast
from typing_extensions import ParamSpec
from . import combine
from ..core import Streamer, pipable_operator
from ..manager import StreamerManager
__all__ = ["concat", "flatten", "switch", "concatmap", "flatmap", "switchmap"]
T = TypeVar("T")
U = TypeVar("U")
P = ParamSpec("P")
# Helper to manage streams of higher order
@pipable_operator
async def base_combine(
source: AsyncIterable[AsyncIterable[T]],
switch: bool = False,
ordered: bool = False,
task_limit: int | None = None,
) -> AsyncIterator[T]:
"""Base operator for managing an asynchronous sequence of sequences.
    The sequences are awaited concurrently, although it's possible to limit
    the number of running sequences using the `task_limit` argument.
    The ``switch`` argument enables the switch mechanism, which causes the
    previous subsequence to be discarded when a new one is created.
The items can either be generated in order or as soon as they're received,
depending on the ``ordered`` argument.
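
    For reference, the public operators in this module reduce to this
    helper as follows (a summary of the calls below, not extra behavior)::

        concat(source)  -> base_combine(source, switch=False, ordered=True)
        flatten(source) -> base_combine(source, switch=False, ordered=False)
        switch(source)  -> base_combine(source, switch=True)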
"""
# Task limit
    if task_limit is not None and task_limit <= 0:
raise ValueError("The task limit must be None or greater than 0")
# Safe context
async with StreamerManager[Union[AsyncIterable[T], T]]() as manager:
main_streamer: Streamer[
AsyncIterable[T] | T
] | None = await manager.enter_and_create_task(source)
# Loop over events
while manager.tasks:
# Extract streamer groups
substreamers = manager.streamers[1:]
mainstreamers = [main_streamer] if main_streamer in manager.tasks else []
# Switch - use the main streamer then the substreamer
if switch:
filters = mainstreamers + substreamers
# Concat - use the first substreamer then the main streamer
elif ordered:
filters = substreamers[:1] + mainstreamers
# Flat - use the substreamers then the main streamer
else:
filters = substreamers + mainstreamers
# Wait for next event
streamer, task = await manager.wait_single_event(filters)
# Get result
try:
result = task.result()
# End of stream
except StopAsyncIteration:
# Main streamer is finished
if streamer is main_streamer:
main_streamer = None
# A substreamer is finished
else:
await manager.clean_streamer(streamer)
# Re-schedule the main streamer if necessary
if main_streamer is not None and main_streamer not in manager.tasks:
manager.create_task(main_streamer)
# Process result
else:
                # Switch mechanism
if switch and streamer is main_streamer:
await manager.clean_streamers(substreamers)
# Setup a new source
if streamer is main_streamer:
assert isinstance(result, AsyncIterable)
await manager.enter_and_create_task(result)
# Re-schedule the main streamer if task limit allows it
if task_limit is None or task_limit > len(manager.tasks):
manager.create_task(streamer)
# Yield the result
else:
item = cast("T", result)
yield item
# Re-schedule the streamer
manager.create_task(streamer)
# Advanced operators (for streams of higher order)
@pipable_operator
def concat(
source: AsyncIterable[AsyncIterable[T]], task_limit: int | None = None
) -> AsyncIterator[T]:
"""Given an asynchronous sequence of sequences, generate the elements
of the sequences in order.
    The sequences are awaited concurrently, although it's possible to limit
    the number of running sequences using the `task_limit` argument.
Errors raised in the source or an element sequence are propagated.
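
    Usage sketch (assuming ``from aiostream import stream``; the sources
    are illustrative)::

        async def main():
            xs = stream.iterate([stream.range(0, 3), stream.range(10, 13)])
            # Substreams are exhausted one after the other, in order
            assert await stream.list(stream.concat(xs)) == [0, 1, 2, 10, 11, 12]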
"""
return base_combine.raw(source, task_limit=task_limit, switch=False, ordered=True)
@pipable_operator
def flatten(
source: AsyncIterable[AsyncIterable[T]], task_limit: int | None = None
) -> AsyncIterator[T]:
"""Given an asynchronous sequence of sequences, generate the elements
of the sequences as soon as they're received.
    The sequences are awaited concurrently, although it's possible to limit
    the number of running sequences using the `task_limit` argument.
Errors raised in the source or an element sequence are propagated.
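
    Usage sketch (assuming ``from aiostream import stream``; the
    intervals and the exact interleaving are illustrative)::

        async def main():
            a = stream.range(0, 3, interval=0.2)    # 0, 1, 2
            b = stream.range(10, 13, interval=0.2)  # 10, 11, 12
            xs = stream.iterate([a, b])
            # Items are yielded as they arrive, e.g. [0, 10, 1, 11, 2, 12]
            print(await stream.list(stream.flatten(xs)))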
"""
return base_combine.raw(source, task_limit=task_limit, switch=False, ordered=False)
@pipable_operator
def switch(source: AsyncIterable[AsyncIterable[T]]) -> AsyncIterator[T]:
"""Given an asynchronous sequence of sequences, generate the elements of
the most recent sequence.
Element sequences are generated eagerly, and closed once they are
superseded by a more recent sequence. Once the main sequence is finished,
the last subsequence will be exhausted completely.
Errors raised in the source or an element sequence (that was not already
closed) are propagated.
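
    Usage sketch (assuming ``from aiostream import stream``; the timings
    are illustrative)::

        async def main():
            outer = stream.range(0, 3, interval=0.5)
            # A new substream starts every 0.5s, discarding the previous one
            inner = stream.map(outer, lambda i: stream.range(i * 10, i * 10 + 10, interval=0.2))
            # Only the last substream (20..29) runs to completion
            print(await stream.list(stream.switch(inner)))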
"""
return base_combine.raw(source, switch=True)
# Advanced *-map operators
@pipable_operator
def concatmap(
source: AsyncIterable[T],
func: combine.SmapCallable[T, AsyncIterable[U]],
*more_sources: AsyncIterable[T],
task_limit: int | None = None,
) -> AsyncIterator[U]:
"""Apply a given function that creates a sequence from the elements of one
or several asynchronous sequences, and generate the elements of the created
sequences in order.
    The function is applied as described in `map`, and must return an
    asynchronous sequence. The returned sequences are awaited concurrently,
    although it's possible to limit the number of running sequences using
    the `task_limit` argument.
    Errors raised in a source or output sequence are propagated.
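
    Usage sketch (assuming ``from aiostream import stream``)::

        async def main():
            xs = stream.range(0, 3)
            # Each item expands to a two-element substream, kept in order
            ys = stream.concatmap(xs, lambda i: stream.range(i * 10, i * 10 + 2))
            assert await stream.list(ys) == [0, 1, 10, 11, 20, 21]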
"""
mapped = combine.smap.raw(source, func, *more_sources)
return concat.raw(mapped, task_limit=task_limit)
@pipable_operator
def flatmap(
source: AsyncIterable[T],
func: combine.SmapCallable[T, AsyncIterable[U]],
*more_sources: AsyncIterable[T],
task_limit: int | None = None,
) -> AsyncIterator[U]:
"""Apply a given function that creates a sequence from the elements of one
or several asynchronous sequences, and generate the elements of the created
sequences as soon as they arrive.
    The function is applied as described in `map`, and must return an
    asynchronous sequence. The returned sequences are awaited concurrently,
    although it's possible to limit the number of running sequences using
    the `task_limit` argument.
Errors raised in a source or output sequence are propagated.
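
    Usage sketch (assuming ``from aiostream import stream``; the
    intervals and interleaving are illustrative)::

        async def main():
            xs = stream.range(0, 3, interval=0.1)
            ys = stream.flatmap(xs, lambda i: stream.range(i * 10, i * 10 + 2, interval=0.15))
            # Substream items may interleave, e.g. [0, 10, 1, 11, 20, ...]
            print(await stream.list(ys))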
"""
mapped = combine.smap.raw(source, func, *more_sources)
return flatten.raw(mapped, task_limit=task_limit)
@pipable_operator
def switchmap(
source: AsyncIterable[T],
func: combine.SmapCallable[T, AsyncIterable[U]],
*more_sources: AsyncIterable[T],
) -> AsyncIterator[U]:
"""Apply a given function that creates a sequence from the elements of one
or several asynchronous sequences and generate the elements of the most
recently created sequence.
The function is applied as described in `map`, and must return an
asynchronous sequence. Errors raised in a source or output sequence (that
was not already closed) are propagated.
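
    Usage sketch (assuming ``from aiostream import stream``; ``search``
    is a hypothetical helper returning an async sequence of results)::

        async def main():
            queries = stream.spaceout(stream.iterate(["a", "ab", "abc"]), 0.5)
            # Each new query discards the in-flight results of the
            # previous one (a debounce-like pattern)
            results = stream.switchmap(queries, search)
            print(await stream.list(results))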
"""
mapped = combine.smap.raw(source, func, *more_sources)
return switch.raw(mapped)