1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165
|
import asyncio
from collections.abc import AsyncGenerator
from unittest.mock import AsyncMock, patch
import pytest
import zarr
from zarr.core.sync import (
SyncError,
SyncMixin,
_get_executor,
_get_lock,
_get_loop,
cleanup_resources,
loop,
sync,
)
@pytest.fixture(params=[True, False])
def sync_loop(request: pytest.FixtureRequest) -> asyncio.AbstractEventLoop | None:
    """Supply the ``loop`` argument that parametrized tests forward to ``sync``.

    Each dependent test runs twice: once with the loop returned by
    ``_get_loop()`` (param ``True``) and once with ``None`` (param ``False``).
    """
    # Identity check against ``True`` mirrors the literal fixture params above.
    return _get_loop() if request.param is True else None
@pytest.fixture
def clean_state():
    """Bracket a test with ``cleanup_resources()`` calls.

    Intended to make sure no existing threads/loops exist in
    ``zarr.core.sync`` when the test starts, and to tear them down again
    once the test has finished.
    """
    # Setup: discard whatever earlier tests may have left behind.
    cleanup_resources()
    yield
    # Teardown: discard whatever this test created.
    cleanup_resources()
def test_get_loop() -> None:
    """Two consecutive ``_get_loop()`` calls must hand back the same object."""
    first_loop = _get_loop()
    second_loop = _get_loop()
    assert first_loop is second_loop
def test_get_lock() -> None:
    """Two consecutive ``_get_lock()`` calls must hand back the same object."""
    first_lock = _get_lock()
    second_lock = _get_lock()
    assert first_lock is second_lock
def test_sync(sync_loop: asyncio.AbstractEventLoop | None) -> None:
    """``sync`` returns the coroutine's result and awaits it exactly once."""
    async_mock = AsyncMock(return_value="foo")
    result = sync(async_mock(), loop=sync_loop)
    assert result == "foo"
    async_mock.assert_awaited_once()
def test_sync_raises(sync_loop: asyncio.AbstractEventLoop | None) -> None:
    """An exception raised inside the coroutine surfaces from ``sync``."""
    failing_mock = AsyncMock(side_effect=ValueError("foo-bar"))
    with pytest.raises(ValueError, match="foo-bar"):
        sync(failing_mock(), loop=sync_loop)
    # The coroutine must actually have run for the error to originate from it.
    failing_mock.assert_awaited_once()
def test_sync_timeout() -> None:
    """``sync`` raises ``asyncio.TimeoutError`` when the coroutine outlasts ``timeout``."""
    sleep_seconds = 0.02
    # The timeout is a tenth of the sleep, so the coroutine cannot finish in time.
    timeout_seconds = sleep_seconds / 10

    async def sleeper() -> None:
        await asyncio.sleep(sleep_seconds)

    with pytest.raises(asyncio.TimeoutError):
        sync(sleeper(), timeout=timeout_seconds)
def test_sync_raises_if_no_coroutine(sync_loop: asyncio.AbstractEventLoop | None) -> None:
    """Passing a plain value instead of a coroutine makes ``sync`` raise ``TypeError``."""

    def not_async() -> str:
        return "foo"

    with pytest.raises(TypeError):
        # ``not_async()`` evaluates to a ``str``, not an awaitable.
        sync(not_async(), loop=sync_loop)  # type: ignore[arg-type]
@pytest.mark.filterwarnings("ignore:coroutine.*was never awaited")
def test_sync_raises_if_loop_is_closed() -> None:
    """``sync`` raises ``RuntimeError`` when the loop reports itself as closed."""
    target_loop = _get_loop()
    async_mock = AsyncMock(return_value="foo")

    # Only ``is_closed`` is patched; the real loop is left running.
    with patch.object(target_loop, "is_closed", return_value=True), pytest.raises(RuntimeError):
        sync(async_mock(), loop=target_loop)

    # The coroutine object was created but must never have been awaited.
    async_mock.assert_not_awaited()
@pytest.mark.filterwarnings("ignore:Unclosed client session:ResourceWarning")
@pytest.mark.filterwarnings("ignore:coroutine.*was never awaited")
def test_sync_raises_if_calling_sync_from_within_a_running_loop(
    sync_loop: asyncio.AbstractEventLoop | None,
) -> None:
    """Calling ``sync`` from a coroutine that ``sync`` is already running raises ``SyncError``."""

    def inner() -> str:
        # Strictly speaking this ought to be a coroutine function, but then the
        # nested call would leave it un-awaited and emit a warning.
        return "foo"

    async def outer() -> str:
        # Re-entrant call: ``sync`` invoked while ``outer`` itself runs under ``sync``.
        return sync(inner(), loop=sync_loop)  # type: ignore[arg-type]

    with pytest.raises(SyncError):
        sync(outer(), loop=sync_loop)
@pytest.mark.filterwarnings("ignore:coroutine.*was never awaited")
def test_sync_raises_if_loop_is_invalid_type() -> None:
    """A ``loop`` argument that is not an event loop makes ``sync`` raise ``TypeError``."""
    async_mock = AsyncMock(return_value="foo")

    with pytest.raises(TypeError):
        sync(async_mock(), loop=1)  # type: ignore[arg-type]

    # The bad loop argument must be rejected before the coroutine is awaited.
    async_mock.assert_not_awaited()
def test_sync_mixin(sync_loop) -> None:
    """``SyncMixin._sync`` / ``_sync_iter`` expose async members synchronously.

    ``sync_loop`` is not referenced in the body; requesting it only makes the
    test run once per fixture parameter.
    """

    class AsyncFoo:
        """Async side: one coroutine method and one async generator."""

        async def foo(self) -> str:
            return "foo"

        async def bar(self) -> AsyncGenerator:
            for value in range(10):
                yield value

    class SyncFoo(SyncMixin):
        """Synchronous facade over an ``AsyncFoo`` built on ``SyncMixin``."""

        def __init__(self, async_foo: AsyncFoo) -> None:
            self._async_foo = async_foo

        def foo(self) -> str:
            return self._sync(self._async_foo.foo())

        def bar(self) -> list[int]:
            return self._sync_iter(self._async_foo.bar())

    wrapper = SyncFoo(AsyncFoo())

    assert wrapper.foo() == "foo"
    # The async generator's items come back collected into a list.
    assert wrapper.bar() == list(range(10))
@pytest.mark.parametrize("workers", [None, 1, 2])
def test_threadpool_executor(clean_state, workers: int | None) -> None:
    """Check how ``threading.max_workers`` affects the loop's default executor.

    With ``workers=None`` no default executor is expected on the loop;
    otherwise the executor from ``_get_executor()`` must be the loop's default
    executor and must be sized to ``workers``.
    """
    config_override = {"threading.max_workers": workers}
    with zarr.config.set(config_override):
        _ = zarr.zeros(shape=(1,))  # trigger executor creation
        assert loop != [None]  # confirm loop was created

        if workers is not None:
            # Confirm the executor was created and attached to the loop as its
            # default executor. Python offers no direct way to read the default
            # executor, hence the private attribute.
            assert _get_executor() is loop[0]._default_executor
            assert _get_executor()._max_workers == workers
        else:
            # Default behaviour: no workers specified, so no executor is created.
            assert loop[0]._default_executor is None
def test_cleanup_resources_idempotent() -> None:
    """Calling ``cleanup_resources()`` twice in a row must not raise."""
    _get_executor()  # trigger resource creation (iothread, loop, thread-pool)
    # The second pass runs against resources that were already cleaned up.
    for _ in range(2):
        cleanup_resources()
|