1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140
|
import asyncio
import contextvars
import sys
import threading
import time
import pytest
from asgiref.sync import ThreadSensitiveContext, async_to_sync, sync_to_async
# Module-wide context variable that the tests below read and write.
foo: "contextvars.ContextVar[str]"
foo = contextvars.ContextVar("foo")
@pytest.mark.asyncio
async def test_thread_sensitive_with_context_different():
    """
    Two concurrently running ``ThreadSensitiveContext`` blocks must execute a
    thread-sensitive ``sync_to_async`` function on two different threads,
    and neither of them on the thread running this test.
    """
    result_1 = {}
    result_2 = {}

    @sync_to_async
    def store_thread(result):
        # Record which thread actually executed the sync function.
        result["thread"] = threading.current_thread()

    async def fn(result):
        async with ThreadSensitiveContext():
            await store_thread(result)

    # Run both concurrently (in true parallel!). Unlike asyncio.wait, gather
    # re-raises an exception from either coroutine, so a failure inside fn
    # is reported directly instead of as a KeyError on the lookups below.
    await asyncio.gather(fn(result_1), fn(result_2))

    caller_thread = threading.current_thread()
    # Neither call may have run on this test's thread...
    assert result_1["thread"] != caller_thread
    assert result_2["thread"] != caller_thread
    # ...and the two contexts must not have shared a thread.
    assert result_1["thread"] != result_2["thread"]
@pytest.mark.asyncio
async def test_sync_to_async_contextvars():
    """
    A function wrapped with ``sync_to_async`` must see the context variables
    of the awaiting coroutine, and any value it sets must be visible to the
    caller once the await has completed.
    """

    def set_foo_to_baz():
        # Sleep before touching the variable, then check the inherited value.
        time.sleep(1)
        assert foo.get() == "bar"
        foo.set("baz")
        return 42

    wrapped = sync_to_async(set_foo_to_baz)
    foo.set("bar")

    outcome = await wrapped()
    assert outcome == 42
    # The write performed inside the sync function is visible here.
    assert foo.get() == "baz"
@pytest.mark.asyncio
async def test_sync_to_async_contextvars_with_custom_context():
    """
    When ``sync_to_async`` is given an explicit ``context``, writes made by
    the wrapped function land only in that context; the caller's own context
    keeps the value it had before the call.
    """

    def mutate_foo():
        time.sleep(1)
        assert foo.get() == "bar"
        foo.set("baz")
        return 42

    foo.set("bar")
    # Snapshot taken after setting "bar", so mutate_foo starts from "bar".
    custom_ctx = contextvars.copy_context()

    assert await sync_to_async(mutate_foo, context=custom_ctx)() == 42

    # Caller's context: untouched.
    assert foo.get() == "bar"
    # Custom context: carries the write made inside mutate_foo.
    assert custom_ctx.get(foo) == "baz"
@pytest.mark.asyncio
@pytest.mark.skipif(sys.version_info < (3, 11), reason="requires python3.11")
async def test_sync_to_async_contextvars_with_custom_context_and_parallel_tasks():
    """
    Sync functions run through ``sync_to_async`` and asyncio tasks that all
    share one custom context write into that context, while the caller's
    context is left exactly as it was.
    """
    foo.set("")

    def append_one_sync():
        foo.set(foo.get() + "1")
        return 1

    async def append_one_async():
        foo.set(foo.get() + "1")
        return 1

    shared = contextvars.copy_context()
    # The ``context=`` argument of create_task needs Python 3.11 (see skipif).
    awaitables = [
        sync_to_async(append_one_sync, context=shared)(),
        sync_to_async(append_one_sync, context=shared)(),
        asyncio.create_task(append_one_async(), context=shared),
        asyncio.create_task(append_one_async(), context=shared),
    ]
    await asyncio.gather(*awaitables)

    # Caller's context: still the empty string.
    assert foo.get() == ""
    # Shared context: one "1" appended by each of the four writers.
    assert shared.get(foo) == "1111"
def test_async_to_sync_contextvars():
    """
    Tests to make sure that contextvars from the calling context are
    present in the called context, and that any changes in the called context
    are then propagated back to the calling context.
    """

    async def async_function():
        await asyncio.sleep(1)
        assert foo.get() == "bar"
        foo.set("baz")
        return 42

    # This test is synchronous, so foo.set() (and the "baz" copied back by
    # async_to_sync) mutates the context of the thread running pytest and
    # would otherwise persist after the test returns. Keep the token so the
    # variable is restored to its prior state no matter how the test ends.
    token = foo.set("bar")
    try:
        sync_function = async_to_sync(async_function)
        assert sync_function() == 42
        assert foo.get() == "baz"
    finally:
        foo.reset(token)
|