1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297
|
"""Multiplexer message queues."""
from __future__ import annotations
import asyncio
from collections import OrderedDict, deque
from collections.abc import Callable
import contextlib
from dataclasses import dataclass, field
import logging
from .message import HEADER_SIZE, MultiplexerChannelId, MultiplexerMessage
# Module-level logger (used for debug output when channels are created).
_LOGGER = logging.getLogger(__name__)
@dataclass(slots=True)
class _ChannelQueue:
"""Channel queue.
A queue that manages a single channel, with a size limit.
total_bytes: the size of the queue in bytes instead of the number of items.
queue: a deque of MultiplexerMessage | None.
putters: a deque of asyncio.Future[None] which is used to wake up putters
when the queue is full and space becomes available.
"""
under_water_callback: Callable[[bool], None]
total_bytes: int = 0
under_water: bool = False
pending_close: bool = False
queue: deque[MultiplexerMessage | None] = field(default_factory=deque)
putters: deque[asyncio.Future[None]] = field(default_factory=deque)
def _effective_size(message: MultiplexerMessage | None) -> int:
    """Return how many bytes ``message`` counts for in a queue's budget.

    A ``None`` entry costs nothing; a real message costs its payload length
    plus the fixed ``HEADER_SIZE``.
    """
    if message is None:
        return 0
    return len(message.data) + HEADER_SIZE
class MultiplexerSingleChannelQueue(asyncio.Queue[MultiplexerMessage | None]):
    """Byte-budgeted asyncio queue for a single multiplexer channel.

    ``qsize()`` reports the number of bytes held (the effective size of each
    entry) instead of the number of items, so ``maxsize`` acts as a byte
    budget.

    The byte total may exceed ``maxsize`` by up to one message. Preventing
    that would require reimplementing asyncio.Queue rather than subclassing
    it, which is not worth it since the overshoot is acceptable.

    ``under_water_callback`` is called with ``True`` once the byte total
    reaches ``high_water_mark`` and with ``False`` once it falls back to
    ``low_water_mark`` or below.
    """

    _total_bytes: int = 0

    def __init__(
        self,
        maxsize: int,
        low_water_mark: int,
        high_water_mark: int,
        under_water_callback: Callable[[bool], None],
    ) -> None:
        """Initialize Multiplexer Queue."""
        super().__init__(maxsize)
        self._under_water_callback = under_water_callback
        self._under_water: bool = False
        self._high_water_mark = high_water_mark
        self._low_water_mark = low_water_mark

    def _put(self, message: MultiplexerMessage | None) -> None:
        """Add ``message`` and signal when the high-water mark is reached."""
        self._total_bytes += _effective_size(message)
        super()._put(message)
        reached_high = self._total_bytes >= self._high_water_mark
        if reached_high and not self._under_water:
            self._under_water = True
            self._under_water_callback(True)

    def _get(self) -> MultiplexerMessage | None:
        """Remove the oldest entry and signal when drained to the low-water mark."""
        item = super()._get()
        self._total_bytes -= _effective_size(item)
        if not self._under_water or self._total_bytes > self._low_water_mark:
            return item
        self._under_water = False
        self._under_water_callback(False)
        return item

    def qsize(self) -> int:
        """Return the number of bytes currently queued."""
        return self._total_bytes
class MultiplexerMultiChannelQueue:
    """Multiplexer multi channel queue.

    A queue that manages multiple channels, each with a size limit.

    This class allows for asynchronous message passing between multiple channels,
    ensuring that each channel does not exceed a specified size limit.

    When fetching from the queue, the channels are fetched in a round-robin
    fashion, ensuring that no channel is starved.
    """

    def __init__(
        self,
        channel_size_limit: int,
        channel_low_water_mark: int,
        channel_high_water_mark: int,
    ) -> None:
        """Initialize Multiplexer Queue.

        Must be called while an asyncio event loop is running: the running
        loop is captured here and used to create waiter futures.

        Args:
            channel_size_limit (int): The maximum size of a channel
                data queue in bytes.
            channel_low_water_mark (int): An under-water channel whose queued
                bytes drop to this value or below gets its callback invoked
                with ``False``.
            channel_high_water_mark (int): A channel whose queued bytes reach
                this value or above gets its callback invoked with ``True``.
        """
        self._channel_size_limit = channel_size_limit
        self._channel_low_water_mark = channel_low_water_mark
        self._channel_high_water_mark = channel_high_water_mark
        self._channels: dict[MultiplexerChannelId, _ChannelQueue] = {}
        # _order controls which channel_id to get next. We use
        # an OrderedDict because we need to use popitem(last=False)
        # here to maintain FIFO order.
        self._order: OrderedDict[MultiplexerChannelId, None] = OrderedDict()
        self._getters: deque[asyncio.Future[None]] = deque()
        self._loop = asyncio.get_running_loop()

    def create_channel(
        self,
        channel_id: MultiplexerChannelId,
        under_water_callback: Callable[[bool], None],
    ) -> None:
        """Create a new channel.

        Args:
            channel_id: Identifier of the channel to create.
            under_water_callback: Called with ``True`` when the channel's
                queued bytes reach the high-water mark and with ``False``
                when they fall back to the low-water mark or below.

        Raises:
            RuntimeError: If a channel with this id is still registered
                (including one that is pending close).
        """
        _LOGGER.debug("Queue creating channel %s", channel_id)
        if channel_id in self._channels:
            raise RuntimeError(f"Channel {channel_id} already exists")
        self._channels[channel_id] = _ChannelQueue(under_water_callback)

    def delete_channel(self, channel_id: MultiplexerChannelId) -> None:
        """Delete a channel.

        If the channel still holds queued entries it is only marked as
        pending close and is dropped once get_nowait() consumes its last
        entry. Deleting an unknown channel is a no-op.
        """
        if not (channel := self._channels.get(channel_id)):
            return
        if channel.queue:
            channel.pending_close = True
            return
        del self._channels[channel_id]
        # Any put() still waiting on this channel can never be admitted;
        # wake them so they raise instead of blocking forever.
        self._wakeup_all(channel.putters)

    def _wakeup_next(self, waiters: deque[asyncio.Future[None]]) -> None:
        """Wake up the next waiter."""
        while waiters:
            waiter = waiters.popleft()
            if not waiter.done():
                waiter.set_result(None)
                break

    def _wakeup_all(self, waiters: deque[asyncio.Future[None]]) -> None:
        """Wake up every pending waiter and empty ``waiters``."""
        while waiters:
            waiter = waiters.popleft()
            if not waiter.done():
                waiter.set_result(None)

    async def put(
        self,
        channel_id: MultiplexerChannelId,
        message: MultiplexerMessage | None,
    ) -> None:
        """Put a message in the queue, waiting while the channel is full.

        Raises:
            RuntimeError: If the channel does not exist, or if it was
                dropped (or replaced) while this call was waiting for space.
        """
        # Based on asyncio.Queue.put()
        if not (channel := self._channels.get(channel_id)):
            raise RuntimeError(f"Channel {channel_id} does not exist or already closed")
        size = _effective_size(message)
        # NOTE: a message whose size alone exceeds the channel size limit
        # never satisfies this condition, so the call waits until the
        # channel is deleted or the call is cancelled.
        while channel.total_bytes + size > self._channel_size_limit:  # full
            putter = self._loop.create_future()
            channel.putters.append(putter)
            try:
                await putter
            except BaseException:
                putter.cancel()  # Just in case putter is not done yet.
                with contextlib.suppress(ValueError):
                    # Clean channel.putters from canceled putters.
                    channel.putters.remove(putter)
                if not self.full(channel_id) and not putter.cancelled():
                    # We were woken up by get_nowait(), but can't take
                    # the call. Wake up the next in line.
                    self._wakeup_next(channel.putters)
                raise
            # While we were suspended the channel may have been drained and
            # dropped after a delete_channel(). Putting into the stale object
            # would re-add channel_id to _order and make get_nowait() fail
            # with KeyError, so refuse instead.
            if self._channels.get(channel_id) is not channel:
                raise RuntimeError(f"Channel {channel_id} does not exist or already closed")
        self._put(channel_id, channel, message, size)

    def put_nowait(
        self,
        channel_id: MultiplexerChannelId,
        message: MultiplexerMessage | None,
    ) -> None:
        """Put a message in the queue.

        Raises:
            RuntimeError: If the channel does not exist.
            asyncio.QueueFull: If the queue is full.
        """
        if not (channel := self._channels.get(channel_id)):
            raise RuntimeError(f"Channel {channel_id} does not exist or already closed")
        size = _effective_size(message)
        if channel.total_bytes + size > self._channel_size_limit:
            raise asyncio.QueueFull
        self._put(channel_id, channel, message, size)

    def put_nowait_force(
        self,
        channel_id: MultiplexerChannelId,
        message: MultiplexerMessage | None,
    ) -> None:
        """Put a message in the queue.

        This method is used to force a message into the queue without
        checking if the queue is full. This is used when a channel is
        being closed.

        Raises:
            RuntimeError: If the channel does not exist.
        """
        if not (channel := self._channels.get(channel_id)):
            raise RuntimeError(f"Channel {channel_id} does not exist or already closed")
        self._put(channel_id, channel, message, _effective_size(message))

    def _put(
        self,
        channel_id: MultiplexerChannelId,
        channel: _ChannelQueue,
        message: MultiplexerMessage | None,
        size: int,
    ) -> None:
        """Append to ``channel``, schedule it for reading and wake a getter."""
        channel.queue.append(message)
        channel.total_bytes += size
        # Assigning an existing key keeps its position, so a channel that is
        # already scheduled keeps its round-robin slot.
        self._order[channel_id] = None
        if (
            not channel.under_water
            and channel.total_bytes >= self._channel_high_water_mark
        ):
            channel.under_water = True
            channel.under_water_callback(True)
        self._wakeup_next(self._getters)

    async def get(self) -> MultiplexerMessage | None:
        """Asynchronously retrieve the next `MultiplexerMessage` from the queue.

        Waits until any channel has a queued entry, then behaves like
        get_nowait().
        """
        # Based on asyncio.Queue.get()
        while not self._order:  # order is which channel_id to get next
            getter = self._loop.create_future()
            self._getters.append(getter)
            try:
                await getter
            except BaseException:
                getter.cancel()  # Just in case getter is not done yet.
                with contextlib.suppress(ValueError):
                    # Clean self._getters from canceled getters.
                    self._getters.remove(getter)
                # order is which channel_id to get next
                if self._order and not getter.cancelled():
                    # We were woken up by _put(), but can't take
                    # the call. Wake up the next in line.
                    self._wakeup_next(self._getters)
                raise
        return self.get_nowait()

    def get_nowait(self) -> MultiplexerMessage | None:
        """Get a message from the queue.

        Channels are served round-robin: the channel just read from is moved
        to the back of the order if it still has entries.

        Raises:
            asyncio.QueueEmpty: If the queue is empty.
        """
        if not self._order:
            raise asyncio.QueueEmpty
        channel_id, _ = self._order.popitem(last=False)
        channel = self._channels[channel_id]
        message = channel.queue.popleft()
        channel.total_bytes -= _effective_size(message)
        dropped = False
        if channel.queue:
            # Now put the channel_id back, but at the end of the queue
            # so the next get will get the next waiting channel_id.
            self._order[channel_id] = None
        elif channel.pending_close:
            # Got to the end of the queue and the channel wants
            # to close so we now drop the channel.
            del self._channels[channel_id]
            dropped = True
        if channel.under_water and channel.total_bytes <= self._channel_low_water_mark:
            channel.under_water = False
            channel.under_water_callback(False)
        if dropped:
            # Nobody will consume from this channel again; wake every waiting
            # put() so each one notices the channel is gone and raises.
            self._wakeup_all(channel.putters)
        elif channel.putters:
            self._wakeup_next(channel.putters)
        return message

    def empty(self, channel_id: MultiplexerChannelId) -> bool:
        """Return True if the channel has no queued bytes.

        Unknown channels are reported as empty. ``None`` entries have an
        effective size of 0, so a channel holding only ``None`` entries is
        also reported as empty.
        """
        if not (channel := self._channels.get(channel_id)):
            return True
        return channel.total_bytes == 0

    def size(self, channel_id: MultiplexerChannelId) -> int:
        """Return the size of the channel queue in bytes (0 if unknown)."""
        if not (channel := self._channels.get(channel_id)):
            return 0
        return channel.total_bytes

    def full(self, channel_id: MultiplexerChannelId) -> bool:
        """Return True if the channel queue is full (False if unknown)."""
        if not (channel := self._channels.get(channel_id)):
            return False
        return channel.total_bytes >= self._channel_size_limit
|