File: test_queue.py

package info (click to toggle)
python-snitun 0.45.1-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 640 kB
  • sloc: python: 6,681; sh: 5; makefile: 3
file content (604 lines) | stat: -rw-r--r-- 23,677 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
"""Test Multiplexer queue."""

import asyncio
import os
import re

import pytest

from snitun.multiplexer.message import (
    CHANNEL_FLOW_DATA,
    HEADER_SIZE,
    MultiplexerChannelId,
    MultiplexerMessage,
)
from snitun.multiplexer.queue import (
    MultiplexerMultiChannelQueue,
    MultiplexerSingleChannelQueue,
)

MOCK_MSG_SIZE: int = 4  # default payload size, in bytes, of _make_mock_message


def _make_mock_channel_id() -> MultiplexerChannelId:
    """Return a fresh channel id built from 16 random bytes."""
    random_bytes = os.urandom(16)
    return MultiplexerChannelId(random_bytes)


def _make_mock_message(
    channel_id: MultiplexerChannelId,
    size: int = MOCK_MSG_SIZE,
) -> MultiplexerMessage:
    """Build a CHANNEL_FLOW_DATA message addressed to *channel_id*.

    The payload consists of ``size`` random bytes (MOCK_MSG_SIZE by default).
    """
    payload = os.urandom(size)
    return MultiplexerMessage(channel_id, CHANNEL_FLOW_DATA, payload)


async def test_get_non_existent_channels() -> None:
    """Query a MultiplexerMultiChannelQueue about channels it never created."""
    multi_queue = MultiplexerMultiChannelQueue(100000, 10, 1000)
    assert multi_queue.empty(_make_mock_channel_id())
    assert not multi_queue.full(_make_mock_channel_id())
    assert multi_queue.size(_make_mock_channel_id()) == 0
    # The read-only lookups above must not have created channel entries
    # as a side effect (e.g. through defaultdict auto-insertion).
    assert not multi_queue._channels


async def test_single_channel_queue() -> None:
    """Exercise put/get and size accounting of MultiplexerSingleChannelQueue."""
    single_queue = MultiplexerSingleChannelQueue(100, 10, 50, lambda _: None)
    message = _make_mock_message(_make_mock_channel_id())
    assert single_queue.qsize() == 0

    single_queue.put_nowait(message)
    # A message is accounted as its payload plus the framing header.
    assert single_queue.qsize() == HEADER_SIZE + len(message.data)
    assert single_queue.get_nowait() == message
    assert single_queue.qsize() == 0

    # A None entry takes up no accounted space.
    single_queue.put_nowait(None)
    assert single_queue.qsize() == 0


async def test_multi_channel_queue_full() -> None:
    """Test MultiplexerMultiChannelQueue getting full.

    Each channel holds at most two mock messages. A full channel makes
    ``put_nowait`` raise ``asyncio.QueueFull`` and makes ``put`` wait until
    a ``get`` frees space; other channels are unaffected.
    """
    msg_size = MOCK_MSG_SIZE + HEADER_SIZE
    # Max two mock messages per channel
    queue = MultiplexerMultiChannelQueue(msg_size * 2, msg_size, msg_size * 2)

    channel_one_id = _make_mock_channel_id()
    channel_two_id = _make_mock_channel_id()
    queue.create_channel(channel_one_id, lambda _: None)
    queue.create_channel(channel_two_id, lambda _: None)

    channel_one_msg = _make_mock_message(channel_one_id)
    channel_two_msg = _make_mock_message(channel_two_id)

    queue.put_nowait(channel_one_id, channel_one_msg)
    queue.put_nowait(channel_one_id, channel_one_msg)
    with pytest.raises(asyncio.QueueFull):
        queue.put_nowait(channel_one_id, channel_one_msg)
    # Channel one being full does not block channel two.
    queue.put_nowait(channel_two_id, channel_two_msg)
    queue.put_nowait(channel_two_id, channel_two_msg)
    with pytest.raises(asyncio.QueueFull):
        queue.put_nowait(channel_two_id, channel_two_msg)

    # A blocking put on a full channel keeps waiting until the timeout hits.
    with pytest.raises(TimeoutError):
        async with asyncio.timeout(0.1):
            await queue.put(channel_one_id, channel_one_msg)

    # The timed-out put must not have added anything.
    assert queue.size(channel_one_id) == msg_size * 2

    add_task = asyncio.create_task(queue.put(channel_one_id, channel_one_msg))
    await asyncio.sleep(0)
    assert not add_task.done()
    assert queue.get_nowait() == channel_one_msg
    await asyncio.sleep(0)
    # Freeing a slot on channel one lets the pending put finish.
    assert add_task.done()
    # Await the finished task so an exception raised inside put() fails the
    # test instead of being dropped with a "never retrieved" warning.
    await add_task
    # Round robin: the next message comes from channel two.
    assert queue.get_nowait() == channel_two_msg


async def test_multi_channel_queue_force_message_on_full() -> None:
    """Force messages into a full MultiplexerMultiChannelQueue channel.

    ``put_nowait_force`` ignores the channel size limit, but raises
    RuntimeError once the channel has been deleted.
    """
    single_msg_size = MOCK_MSG_SIZE + HEADER_SIZE
    # Channel capacity: two mock messages
    queue = MultiplexerMultiChannelQueue(
        single_msg_size * 2, single_msg_size, single_msg_size * 2
    )

    chan_id = _make_mock_channel_id()
    queue.create_channel(chan_id, lambda _: None)
    chan_msg = _make_mock_message(chan_id)

    # Fill the channel to capacity; a regular put is then rejected.
    for _ in range(2):
        queue.put_nowait(chan_id, chan_msg)
    with pytest.raises(asyncio.QueueFull):
        queue.put_nowait(chan_id, chan_msg)

    # Forced puts succeed even though the channel is already full.
    queue.put_nowait_force(chan_id, chan_msg)
    queue.put_nowait_force(chan_id, None)

    # Three real messages are accounted; the None entry adds nothing.
    assert queue.size(chan_id) == single_msg_size * 3

    for _ in range(3):
        assert queue.get_nowait() == chan_msg
    assert queue.get_nowait() is None

    queue.delete_channel(chan_id)
    with pytest.raises(RuntimeError, match="does not exist or already closed"):
        queue.put_nowait_force(chan_id, chan_msg)


async def test_multi_channel_queue_round_robin_get() -> None:
    """MultiplexerMultiChannelQueue serves its channels in round-robin order."""
    msg_size = MOCK_MSG_SIZE + HEADER_SIZE
    # Each channel holds at most two mock messages
    queue = MultiplexerMultiChannelQueue(msg_size * 2, msg_size, msg_size * 2)
    first_id = _make_mock_channel_id()
    second_id = _make_mock_channel_id()
    third_id = _make_mock_channel_id()
    for cid in (first_id, second_id, third_id):
        queue.create_channel(cid, lambda _: None)

    first_msg = _make_mock_message(first_id)
    assert queue.empty(first_id)
    await queue.put(first_id, first_msg)
    assert not queue.empty(first_id)
    assert queue.size(first_id) == len(first_msg.data) + HEADER_SIZE

    second_msg = _make_mock_message(second_id)
    assert queue.empty(second_id)
    await queue.put(second_id, second_msg)
    assert not queue.empty(second_id)
    assert queue.size(second_id) == len(second_msg.data) + HEADER_SIZE

    third_msg = _make_mock_message(third_id)
    assert queue.empty(third_id)
    queue.put_nowait(third_id, third_msg)
    assert not queue.empty(third_id)
    assert queue.size(third_id) == len(third_msg.data) + HEADER_SIZE

    # Drain one message per channel; each channel is empty afterwards.
    drain_order = (
        (first_id, first_msg),
        (second_id, second_msg),
        (third_id, third_msg),
    )
    for cid, expected in drain_order:
        assert queue.get_nowait() == expected
        assert queue.empty(cid)
        assert queue.size(cid) == 0

    with pytest.raises(asyncio.QueueEmpty):
        queue.get_nowait()

    with pytest.raises(TimeoutError):
        async with asyncio.timeout(0.1):
            await queue.get()

    # Enqueue in a scrambled order across the three channels.
    arrivals = (
        (second_id, second_msg),
        (third_id, third_msg),
        (first_id, first_msg),
        (first_id, first_msg),
        (third_id, third_msg),
        (second_id, second_msg),
    )
    for cid, message in arrivals:
        queue.put_nowait(cid, message)

    drained = [queue.get_nowait() for _ in range(6)]
    # Output is interleaved fairly across channels regardless of the
    # order in which messages arrived.
    assert drained == [second_msg, third_msg, first_msg] * 2

    with pytest.raises(asyncio.QueueEmpty):
        queue.get_nowait()


async def test_concurrent_get() -> None:
    """Concurrent get() calls on MultiplexerMultiChannelQueue each get a message."""
    msg_size = MOCK_MSG_SIZE + HEADER_SIZE
    # Each channel holds at most two mock messages
    queue = MultiplexerMultiChannelQueue(msg_size * 2, msg_size, msg_size * 2)
    channel_ids = [_make_mock_channel_id() for _ in range(3)]

    for cid in channel_ids:
        queue.create_channel(cid, lambda _: None)

    messages = [_make_mock_message(cid) for cid in channel_ids]

    # Readers are started before anything is in the queue.
    getters = [asyncio.create_task(queue.get()) for _ in range(3)]

    for cid, message in zip(channel_ids, messages):
        await queue.put(cid, message)

    results = await asyncio.gather(*getters)

    for message in messages:
        assert message in results

    with pytest.raises(asyncio.QueueEmpty):
        queue.get_nowait()


async def test_cancel_one_get() -> None:
    """Cancelling a pending get() leaves the first queued message in place."""
    mux_queue = MultiplexerMultiChannelQueue(100000, 10, 10000)
    pending_reader = asyncio.create_task(mux_queue.get())
    cid = _make_mock_channel_id()
    mux_queue.create_channel(cid, lambda _: None)

    first = _make_mock_message(cid)
    second = _make_mock_message(cid)

    # Let the reader start waiting on the empty queue.
    await asyncio.sleep(0)

    mux_queue.put_nowait(cid, first)
    mux_queue.put_nowait(cid, second)
    # Cancel before the reader gets a chance to run again.
    pending_reader.cancel()

    with pytest.raises(asyncio.CancelledError):
        await pending_reader

    assert await mux_queue.get() == first


async def test_reader_cancellation() -> None:
    """Cancel one of three waiting readers on a MultiplexerMultiChannelQueue.

    Expectations:
        - The cancelled reader raises ``asyncio.CancelledError``.
        - The two surviving readers receive both queued messages, in any order.
    """
    mux_queue = MultiplexerMultiChannelQueue(100000, 10, 10000)
    cid = _make_mock_channel_id()
    mux_queue.create_channel(cid, lambda _: None)
    first = _make_mock_message(cid)
    second = _make_mock_message(cid)

    async with asyncio.TaskGroup() as group:
        cancelled_reader, reader_a, reader_b = [
            group.create_task(mux_queue.get()) for _ in range(3)
        ]

        # Give every reader a chance to start waiting.
        await asyncio.sleep(0)

        mux_queue.put_nowait(cid, first)
        mux_queue.put_nowait(cid, second)
        cancelled_reader.cancel()

        with pytest.raises(asyncio.CancelledError):
            await cancelled_reader

        await reader_b

    # Task scheduling order is not guaranteed, so compare as sets.
    assert {reader_a.result(), reader_b.result()} == {first, second}


async def test_put_cancel_race() -> None:
    """Test race between putting messages and cancelling the put operation.

    With room for a single message, the first of three concurrent puts
    completes immediately and the other two wait. Cancelling the last waiter
    must not disturb delivery of the remaining queued messages.
    """
    msg_size = MOCK_MSG_SIZE + HEADER_SIZE
    # Max one message
    queue = MultiplexerMultiChannelQueue(msg_size, msg_size, msg_size)
    channel_one_id = _make_mock_channel_id()
    queue.create_channel(channel_one_id, lambda _: None)

    channel_one_msg_1 = _make_mock_message(channel_one_id)
    channel_one_msg_2 = _make_mock_message(channel_one_id)
    channel_one_msg_3 = _make_mock_message(channel_one_id)

    queue.put_nowait(channel_one_id, channel_one_msg_1)
    assert queue.get_nowait() == channel_one_msg_1
    assert queue.empty(channel_one_id)

    put_1 = asyncio.create_task(queue.put(channel_one_id, channel_one_msg_1))
    put_2 = asyncio.create_task(queue.put(channel_one_id, channel_one_msg_2))
    put_3 = asyncio.create_task(queue.put(channel_one_id, channel_one_msg_3))

    await asyncio.sleep(0)
    assert put_1.done()
    assert not put_2.done()
    assert not put_3.done()
    # put_1 has finished; awaiting it surfaces any exception it raised
    # (awaiting an already-done task does not yield to the event loop).
    await put_1

    put_3.cancel()
    await asyncio.sleep(0)
    # done() would also hold for a put that completed normally; require
    # that put_3 really ended by cancellation.
    assert put_3.cancelled()
    assert queue.get_nowait() == channel_one_msg_1
    await asyncio.sleep(0)
    assert queue.get_nowait() == channel_one_msg_2

    await put_2


async def test_putters_cleaned_up_correctly_on_cancellation() -> None:
    """Test that putters are cleaned up correctly when a put operation is canceled."""
    msg_size = MOCK_MSG_SIZE + HEADER_SIZE
    # Max one message
    queue = MultiplexerMultiChannelQueue(msg_size, msg_size, msg_size)
    channel_one_id = _make_mock_channel_id()
    queue.create_channel(channel_one_id, lambda _: None)
    channel_one_msg_1 = _make_mock_message(channel_one_id)
    channel_one_msg_2 = _make_mock_message(channel_one_id)

    # Fill the channel so the next put has to wait.
    queue.put_nowait(channel_one_id, channel_one_msg_1)

    put_task = asyncio.create_task(queue.put(channel_one_id, channel_one_msg_2))
    await asyncio.sleep(0)

    # The blocked put registers exactly one putter on the channel; once the
    # task is canceled, that putter must be removed from the channel putters.
    assert len(queue._channels[channel_one_id].putters) == 1
    put_task.cancel()
    with pytest.raises(asyncio.CancelledError):
        await put_task
    assert len(queue._channels[channel_one_id].putters) == 0


async def test_getters_cleaned_up_correctly_on_cancellation() -> None:
    """A get() cancelled by a timeout leaves no getter registered."""
    one_msg = MOCK_MSG_SIZE + HEADER_SIZE
    # Room for a single message
    mux_queue = MultiplexerMultiChannelQueue(one_msg, one_msg, one_msg)
    with pytest.raises(TimeoutError):
        async with asyncio.timeout(0.1):
            await mux_queue.get()

    assert not mux_queue._getters


async def test_cancelled_when_putter_already_removed() -> None:
    """Cancelling a put after space was freed still raises CancelledError."""
    one_msg = MOCK_MSG_SIZE + HEADER_SIZE
    # Room for a single message
    mux_queue = MultiplexerMultiChannelQueue(one_msg, one_msg, one_msg)
    cid = _make_mock_channel_id()
    mux_queue.create_channel(cid, lambda _: None)
    message = _make_mock_message(cid)

    mux_queue.put_nowait(cid, message)
    # The channel is now full, so this put has to wait.
    blocked_put = asyncio.create_task(mux_queue.put(cid, message))
    await asyncio.sleep(0)

    # Free the slot, then cancel the put before it gets to run again.
    mux_queue.get_nowait()
    blocked_put.cancel()
    with pytest.raises(asyncio.CancelledError):
        await blocked_put


async def test_multiple_getters_waiting_multiple_putters() -> None:
    """Messages from queued put() tasks reach successive get() calls in order."""
    one_msg = MOCK_MSG_SIZE + HEADER_SIZE
    # Room for a single message
    mux_queue = MultiplexerMultiChannelQueue(one_msg, one_msg, one_msg)
    cid = _make_mock_channel_id()
    mux_queue.create_channel(cid, lambda _: None)
    first = _make_mock_message(cid)
    second = _make_mock_message(cid)
    put_tasks = [
        asyncio.create_task(mux_queue.put(cid, first)),
        asyncio.create_task(mux_queue.put(cid, second)),
    ]
    for expected in (first, second):
        assert await mux_queue.get() == expected
    for task in put_tasks:
        await task


async def test_get_cancelled_race() -> None:
    """Test cancelling a get operation while another get operation is in progress.

    Two ``get`` calls wait on an empty queue; the first is cancelled, and the
    message put afterwards must be delivered to the second.
    """
    queue = MultiplexerMultiChannelQueue(10000000, 10, 10000)
    channel_one_id = _make_mock_channel_id()
    queue.create_channel(channel_one_id, lambda _: None)
    channel_one_msg_1 = _make_mock_message(channel_one_id)

    t1 = asyncio.create_task(queue.get())
    t2 = asyncio.create_task(queue.get())

    # Let both getters start waiting.
    await asyncio.sleep(0)
    t1.cancel()
    await asyncio.sleep(0)
    # done() would also hold for a get that had returned a value; require
    # that t1 really ended by cancellation.
    assert t1.cancelled()
    await queue.put(channel_one_id, channel_one_msg_1)
    await asyncio.sleep(0)
    assert await t2 == channel_one_msg_1


async def test_get_with_other_putters() -> None:
    """get() resolves a putter that is registered as waiting on its channel."""
    running_loop = asyncio.get_running_loop()
    mux_queue = MultiplexerMultiChannelQueue(10000000, 10, 10000)
    cid = _make_mock_channel_id()
    mux_queue.create_channel(cid, lambda _: None)
    message = _make_mock_message(cid)

    mux_queue.put_nowait(cid, message)
    # Register a bare future as if another put() were blocked on this channel.
    waiting_putter = running_loop.create_future()
    mux_queue._channels[cid].putters.append(waiting_putter)

    assert await mux_queue.get() == message
    # The get must have completed the waiting putter with a None result.
    assert waiting_putter.done()
    assert await waiting_putter is None

    await mux_queue.put(cid, message)
    assert mux_queue.get_nowait() == message


async def test_get_with_other_putter_already_one() -> None:
    """Test get() when a waiting putter's future has already completed.

    A putter future whose result is already set is registered on the channel
    before ``get`` runs. ``get`` must cope with it (calling ``set_result`` on
    a done future would raise ``asyncio.InvalidStateError``), and the queue
    must stay usable for a normal ``put``/``get`` round trip afterwards.
    """
    loop = asyncio.get_running_loop()
    queue = MultiplexerMultiChannelQueue(10000000, 10, 10000)
    channel_one_id = _make_mock_channel_id()
    queue.create_channel(channel_one_id, lambda _: None)
    channel_one_msg_1 = _make_mock_message(channel_one_id)

    queue.put_nowait(channel_one_id, channel_one_msg_1)
    # Simulate a putter that was already woken up before this get.
    other_putter = loop.create_future()
    other_putter.set_result(None)
    queue._channels[channel_one_id].putters.append(other_putter)

    assert await queue.get() == channel_one_msg_1
    assert other_putter.done()
    assert await other_putter is None

    await queue.put(channel_one_id, channel_one_msg_1)
    assert queue.get_nowait() == channel_one_msg_1


async def test_single_channel_queue_under_water() -> None:
    """Check the under-water callback of MultiplexerSingleChannelQueue.

    The callback reports ``True`` when the queue reaches four messages (high
    watermark) and ``False`` when it drains back down to two (low watermark).
    """
    msg_size = MOCK_MSG_SIZE + HEADER_SIZE
    under_water_callbacks: list[bool] = []

    def on_under_water(under_water: bool) -> None:
        under_water_callbacks.append(under_water)

    queue = MultiplexerSingleChannelQueue(
        msg_size * 10,
        msg_size * 2,
        msg_size * 4,
        on_under_water,
    )
    msg = _make_mock_message(_make_mock_channel_id())
    assert queue.qsize() == 0
    queue.put_nowait(msg)
    assert queue.qsize() == len(msg.data) + HEADER_SIZE
    assert not under_water_callbacks

    def _put() -> None:
        queue.put_nowait(msg)

    def _get() -> None:
        queue.get_nowait()

    # (action, expected callback history afterwards); the trailing comment
    # is the number of queued messages after the action.
    steps = [
        (_put, []),  # 2
        (_put, []),  # 3
        (_put, [True]),  # 4: high watermark reached -> under water
        (_put, [True]),  # 5: still under water
        (_get, [True]),  # 4: low watermark not reached yet
        (_get, [True]),  # 3: low watermark not reached yet
        (_get, [True, False]),  # 2: low watermark reached
        (_get, [True, False]),  # 1: below low watermark
        (_get, [True, False]),  # 0: empty
        (_put, [True, False]),  # 1: below high watermark
        (_put, [True, False]),  # 2: below high watermark
        (_put, [True, False]),  # 3: below high watermark
        (_put, [True, False, True]),  # 4: high watermark reached again
        (_get, [True, False, True]),  # 3: still above low watermark
        (_get, [True, False, True, False]),  # 2: low watermark reached again
    ]
    for action, expected in steps:
        action()
        assert under_water_callbacks == expected


async def test_multi_channel_queue_under_water() -> None:
    """Check the per-channel under-water callback of MultiplexerMultiChannelQueue.

    The callback reports ``True`` when the channel reaches four messages
    (high watermark) and ``False`` when it drains back down to two (low
    watermark).
    """
    msg_size = MOCK_MSG_SIZE + HEADER_SIZE
    under_water_callbacks: list[bool] = []

    def on_under_water(under_water: bool) -> None:
        under_water_callbacks.append(under_water)

    queue = MultiplexerMultiChannelQueue(
        msg_size * 10,
        msg_size * 2,
        msg_size * 4,
    )
    channel_id = _make_mock_channel_id()
    queue.create_channel(channel_id, on_under_water)
    msg = _make_mock_message(channel_id)
    assert queue.empty(channel_id)
    queue.put_nowait(channel_id, msg)
    assert not under_water_callbacks

    def _put() -> None:
        queue.put_nowait(channel_id, msg)

    def _get() -> None:
        queue.get_nowait()

    # (action, expected callback history afterwards); the trailing comment
    # is the number of queued messages after the action.
    steps = [
        (_put, []),  # 2
        (_put, []),  # 3
        (_put, [True]),  # 4: high watermark reached -> under water
        (_put, [True]),  # 5: still under water
        (_get, [True]),  # 4: low watermark not reached yet
        (_get, [True]),  # 3: low watermark not reached yet
        (_get, [True, False]),  # 2: low watermark reached
        (_get, [True, False]),  # 1: below low watermark
        (_get, [True, False]),  # 0: empty
        (_put, [True, False]),  # 1: below high watermark
        (_put, [True, False]),  # 2: below high watermark
        (_put, [True, False]),  # 3: below high watermark
        (_put, [True, False, True]),  # 4: high watermark reached again
        (_get, [True, False, True]),  # 3: still above low watermark
        (_get, [True, False, True, False]),  # 2: low watermark reached again
    ]
    for action, expected in steps:
        action()
        assert under_water_callbacks == expected


async def test_put_nowait_to_non_existent_multi_channel_queue() -> None:
    """Test writing to a non-existent channel with put_nowait."""
    queue = MultiplexerMultiChannelQueue(100000, 10, 1000)
    channel_id = _make_mock_channel_id()
    msg = _make_mock_message(channel_id)
    # ``match`` is a regular expression; escape the expected text so it is
    # matched literally whatever characters str(channel_id) contains.
    expected = re.escape(f"Channel {channel_id} does not exist")
    with pytest.raises(RuntimeError, match=expected):
        queue.put_nowait(channel_id, msg)


async def test_put_to_non_existent_multi_channel_queue() -> None:
    """Test writing to a non-existent channel with the awaitable put."""
    queue = MultiplexerMultiChannelQueue(100000, 10, 1000)
    channel_id = _make_mock_channel_id()
    msg = _make_mock_message(channel_id)
    # ``match`` is a regular expression; escape the expected text so it is
    # matched literally whatever characters str(channel_id) contains.
    expected = re.escape(f"Channel {channel_id} does not exist")
    with pytest.raises(RuntimeError, match=expected):
        await queue.put(channel_id, msg)


async def test_multiple_delete_channel_is_forgiving() -> None:
    """Deleting the same channel twice does not raise."""
    mux_queue = MultiplexerMultiChannelQueue(100000, 10, 1000)
    cid = _make_mock_channel_id()
    mux_queue.create_channel(cid, lambda _: None)
    for _ in range(2):
        mux_queue.delete_channel(cid)


async def test_delete_channel_when_queue_is_not_empty() -> None:
    """Deleting a channel keeps its already-queued message readable."""
    mux_queue = MultiplexerMultiChannelQueue(100000, 10, 1000)
    cid = _make_mock_channel_id()
    mux_queue.create_channel(cid, lambda _: None)
    mux_queue.put_nowait(cid, _make_mock_message(cid))

    mux_queue.delete_channel(cid)
    # The queued message survives the deletion and can still be read...
    assert not mux_queue.empty(cid)
    assert mux_queue.get_nowait() is not None

    # ...and after draining, deleting again leaves the channel empty.
    mux_queue.delete_channel(cid)
    assert mux_queue.empty(cid)


async def test_multiple_create_channel_raises() -> None:
    """Test the same channel can only be created once."""
    queue = MultiplexerMultiChannelQueue(100000, 10, 1000)
    channel_id = _make_mock_channel_id()
    queue.create_channel(channel_id, lambda _: None)
    # ``match`` is a regular expression; escape the expected text so it is
    # matched literally whatever characters str(channel_id) contains.
    expected = re.escape(f"Channel {channel_id} already exists")
    with pytest.raises(RuntimeError, match=expected):
        queue.create_channel(channel_id, lambda _: None)