1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188
|
import asyncio
import sys
import threading
from unittest import mock

if sys.version_info[:2] < (3, 11):
    from async_timeout import timeout as asyncio_timeout  # pragma: no cover
else:
    from asyncio import timeout as asyncio_timeout  # pragma: no cover

import pytest

from bellows.thread import EventLoopThread, ThreadsafeProxy
async def test_thread_start(monkeypatch):
    """EventLoopThread.start() drives a freshly created loop through its lifecycle.

    ``asyncio.new_event_loop`` is patched to return a MagicMock, and its
    ``run_until_complete`` is redirected back to the test's own loop so that
    whatever coroutine the worker passes to it actually executes.
    """
    current_loop = asyncio.get_running_loop()
    loopmock = mock.MagicMock()

    monkeypatch.setattr(asyncio, "new_event_loop", lambda: loopmock)
    monkeypatch.setattr(asyncio, "set_event_loop", lambda loop: None)

    def mockrun(task):
        # Emulates run_until_complete from the worker thread: schedule the
        # coroutine on the test loop and block (up to 1s) for its result.
        future = asyncio.run_coroutine_threadsafe(task, loop=current_loop)
        return future.result(1)

    loopmock.run_until_complete.side_effect = mockrun
    thread = EventLoopThread()
    thread_complete = await thread.start()
    await thread_complete

    # Each lifecycle step of the mocked loop must happen exactly once.
    assert loopmock.run_until_complete.call_count == 1
    assert loopmock.run_forever.call_count == 1
    assert loopmock.close.call_count == 1
class ExceptionCollector:
    """Event-loop exception handler that collects every reported exception.

    Tests install an instance with ``loop.set_exception_handler`` so that
    failures raised inside another loop can later be inspected or re-raised.
    """

    def __init__(self):
        # Exceptions in the order the loop reported them.
        self.exceptions = []

    def __call__(self, thread_loop, context):
        # Use the real exception object when present; otherwise wrap the
        # context's message in a plain Exception.
        error = context.get("exception")
        if not error:
            error = Exception(context["message"])
        self.exceptions.append(error)
@pytest.fixture
async def thread():
    """Yield a started EventLoopThread and verify it shuts down cleanly.

    An ExceptionCollector is installed as the thread loop's exception handler
    so tests can surface errors raised inside that loop. On teardown the
    thread is force-stopped, awaited (up to 1s), and the fixture asserts no
    threads with "bellows" in their name remain.
    """
    thread = EventLoopThread()
    await thread.start()
    # Install the handler from the thread's own loop: event-loop methods are
    # not thread-safe, so schedule the call instead of making it directly.
    thread.loop.call_soon_threadsafe(
        thread.loop.set_exception_handler, ExceptionCollector()
    )

    yield thread

    thread.force_stop()
    if thread.thread_complete is not None:
        async with asyncio_timeout(1):
            await thread.thread_complete

    # Give lingering bellows worker threads a chance to exit before checking.
    for leftover in threading.enumerate():
        if "bellows" in leftover.name:
            leftover.join(1)

    threads = [t for t in threading.enumerate() if "bellows" in t.name]
    assert not threads, f"bellows threads still running: {threads}"
async def yield_other_thread(thread):
    """Round-trip through ``thread``'s loop, then re-raise any collected error.

    A no-op coroutine is awaited on the other loop. Presumably this lets
    callbacks already queued there run first; confirm against
    EventLoopThread.run_coroutine_threadsafe. The first exception recorded by
    the loop's exception handler, if any, is then raised in the caller.
    """
    await thread.run_coroutine_threadsafe(asyncio.sleep(0))
    collected = thread.loop.get_exception_handler().exceptions
    if not collected:
        return
    raise collected[0]
async def test_thread_loop(thread):
    """A coroutine submitted to the thread's loop delivers its result back here."""

    async def test_coroutine():
        return mock.sentinel.result

    future = asyncio.run_coroutine_threadsafe(test_coroutine(), loop=thread.loop)
    # Inside a coroutine, wrap_future already binds to the running loop.
    result = await asyncio.wrap_future(future)
    assert result is mock.sentinel.result
async def test_thread_double_start(thread):
    """Calling start() on a running thread neither spawns a thread nor swaps loops."""
    previous_loop = thread.loop
    await thread.start()

    # The former `sys.version_info >= (3, 6)` guard was always true on every
    # Python this module supports, so the check now runs unconditionally.
    threads = [t for t in threading.enumerate() if "bellows" in t.name]
    assert len(threads) == 1
    assert thread.loop is previous_loop
async def test_thread_already_stopped(thread):
    """force_stop() may be called again on an already-stopped thread without error."""
    for _ in range(2):
        thread.force_stop()
async def test_thread_run_coroutine_threadsafe(thread):
    """EventLoopThread.run_coroutine_threadsafe runs the coroutine on ``thread.loop``."""
    inner_loop = None

    async def test_coroutine():
        nonlocal inner_loop
        # Record the loop that is actually executing this coroutine.
        inner_loop = asyncio.get_running_loop()
        return mock.sentinel.result

    result = await thread.run_coroutine_threadsafe(test_coroutine())
    assert result is mock.sentinel.result
    assert inner_loop is thread.loop
async def test_proxy_callback(thread):
    """A non-async method called through the proxy is invoked exactly once."""
    target = mock.MagicMock()
    wrapped = ThreadsafeProxy(target, thread.loop)
    target.test.return_value = None

    wrapped.test()
    # The call is presumably dispatched to the other loop. Round-trip through
    # it, which also re-raises any collected error, before asserting.
    await yield_other_thread(thread)

    calls = target.test.call_count
    assert calls == 1
async def test_proxy_async(thread):
    """An async method called through the proxy runs on the thread's loop.

    Its return value must be delivered back to the awaiting caller.
    """
    obj = mock.MagicMock()
    proxy = ThreadsafeProxy(obj, thread.loop)
    call_count = 0

    async def magic():
        # `thread` is only read, so only `call_count` needs to be nonlocal.
        nonlocal call_count
        # Loops are compared by identity: it must be the very same loop object.
        assert asyncio.get_running_loop() is thread.loop
        call_count += 1
        return mock.sentinel.result

    obj.test = magic
    result = await proxy.test()

    assert call_count == 1
    assert result is mock.sentinel.result
async def test_proxy_bad_function(thread):
    """A non-async proxied method that returns a value triggers TypeError.

    The TypeError may surface from the call itself or, if it is raised on the
    thread's loop, via yield_other_thread() re-raising what that loop's
    exception handler collected. Both steps sit inside the raises block.
    """
    target = mock.MagicMock()
    wrapped = ThreadsafeProxy(target, thread.loop)
    target.test.return_value = mock.sentinel.value

    expect_type_error = pytest.raises(TypeError)
    with expect_type_error:
        wrapped.test()
        await yield_other_thread(thread)
async def test_proxy_not_function():
    """Accessing a non-callable attribute through the proxy raises TypeError."""
    loop = asyncio.get_running_loop()
    obj = mock.MagicMock()
    proxy = ThreadsafeProxy(obj, loop)
    obj.test = mock.sentinel.value

    with pytest.raises(TypeError):
        # Attribute access alone must raise; no call is needed.
        proxy.test  # noqa: B018
async def test_proxy_no_thread():
    """When the proxy's loop is the running loop, the call reaches the object immediately."""
    loop = asyncio.get_running_loop()
    obj = mock.MagicMock()
    proxy = ThreadsafeProxy(obj, loop)

    proxy.test()
    # No yielding in between: the call is expected to have happened synchronously.
    assert obj.test.call_count == 1
async def test_proxy_loop_closed():
    """Calls through a proxy whose loop is closed do not raise and never reach the object."""
    closed_loop = asyncio.new_event_loop()
    target = mock.MagicMock()
    wrapped = ThreadsafeProxy(target, closed_loop)
    closed_loop.close()

    wrapped.test()
    assert target.test.call_count == 0
async def test_thread_task_cancellation_after_stop(thread):
    """Stopping the thread cancels a proxied coroutine still waiting on its loop.

    The resulting CancelledError must propagate back to the caller awaiting
    the proxy on the test loop.
    """
    loop = asyncio.get_running_loop()
    obj = mock.MagicMock()

    async def wait_forever():
        # A future on the thread's loop that nothing ever resolves.
        return await thread.loop.create_future()

    obj.wait_forever = wait_forever

    # Stop the thread while we're waiting
    loop.call_later(0.1, thread.force_stop)

    proxy = ThreadsafeProxy(obj, thread.loop)

    # The cancellation should propagate to the outer event loop
    with pytest.raises(asyncio.CancelledError):
        # If the cancellation did not propagate, this await would block until
        # the 1s timeout expired and the test would fail with a timeout error
        # rather than hanging forever.
        async with asyncio_timeout(1):
            await proxy.wait_forever()
|