1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76
|
import asyncio
import contextvars
import threading
import time
import pytest
from asgiref.sync import ThreadSensitiveContext, async_to_sync, sync_to_async
# Module-level context variable shared by the contextvars tests below. Each
# test sets it to "bar", and the wrapped call is expected to see "bar" and
# change it to "baz", which must then be visible to the caller.
foo: "contextvars.ContextVar[str]" = contextvars.ContextVar("foo")
@pytest.mark.asyncio
async def test_thread_sensitive_with_context_different():
    """
    Two coroutines that each enter their own ThreadSensitiveContext should
    have their thread-sensitive sync calls run on two different threads,
    neither of which is the thread running this test.
    """
    result_1 = {}
    result_2 = {}

    @sync_to_async
    def store_thread(result):
        result["thread"] = threading.current_thread()

    async def fn(result):
        async with ThreadSensitiveContext():
            await store_thread(result)

    # Run both concurrently (in true parallel!). Unlike asyncio.wait(),
    # gather() re-raises an exception from either task, so a failure inside
    # store_thread is reported directly rather than as a KeyError below.
    await asyncio.gather(fn(result_1), fn(result_2))

    # Neither call may have run in the main thread, and they must not have
    # shared a thread.
    main_thread = threading.current_thread()
    assert result_1["thread"] != main_thread
    assert result_2["thread"] != main_thread
    assert result_1["thread"] != result_2["thread"]
@pytest.mark.asyncio
async def test_sync_to_async_contextvars():
    """
    Check that sync_to_async carries context variables in both directions.

    The wrapped sync function must see the caller's value of ``foo``, and the
    value it assigns must be visible to the caller once the await completes.
    """
    foo.set("bar")

    @sync_to_async
    def blocking_worker():
        time.sleep(1)
        # The caller's value must be visible inside the wrapped function.
        assert foo.get() == "bar"
        # This assignment should flow back to the awaiting coroutine.
        foo.set("baz")
        return 42

    outcome = await blocking_worker()
    assert outcome == 42
    assert foo.get() == "baz"
def test_async_to_sync_contextvars():
    """
    Tests to make sure that contextvars from the calling context are
    present in the called context, and that any changes in the called context
    are then propagated back to the calling context.
    """

    async def async_function():
        await asyncio.sleep(1)
        # The caller's value must be visible inside the coroutine.
        assert foo.get() == "bar"
        # This assignment should be propagated back to the sync caller.
        foo.set("baz")
        return 42

    # This test is synchronous, so foo.set() modifies the context of the
    # thread running the test itself. Keep the token so foo's previous value
    # can be restored afterwards instead of leaking "baz" into later tests.
    token = foo.set("bar")
    try:
        sync_function = async_to_sync(async_function)
        assert sync_function() == 42
        assert foo.get() == "baz"
    finally:
        foo.reset(token)
|