File: test_sync_contextvars.py

package info (click to toggle)
python-asgiref 3.11.0-1
  • links: PTS, VCS
  • area: main
  • in suites: sid
  • size: 440 kB
  • sloc: python: 2,761; makefile: 19
file content (140 lines) | stat: -rw-r--r-- 4,162 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
import asyncio
import contextvars
import sys
import threading
import time

import pytest

from asgiref.sync import ThreadSensitiveContext, async_to_sync, sync_to_async

# Module-level context variable exercised by the contextvar-propagation tests
# below. It is created without a default, so `foo.get()` raises LookupError in
# any context where nothing has set it yet; each test sets it before reading.
foo: "contextvars.ContextVar[str]" = contextvars.ContextVar("foo")


@pytest.mark.asyncio
async def test_thread_sensitive_with_context_different():
    """
    Two coroutines that each enter their own ThreadSensitiveContext must run
    their thread-sensitive sync code on two different threads, neither of
    which is the thread running this test's event loop.
    """
    result_1 = {}
    result_2 = {}

    @sync_to_async
    def store_thread(result):
        result["thread"] = threading.current_thread()

    async def fn(result):
        async with ThreadSensitiveContext():
            await store_thread(result)

    # Run both concurrently (in true parallel!). Unlike asyncio.wait, gather
    # re-raises the first exception raised by either coroutine, so a failure
    # inside fn surfaces here instead of as a confusing KeyError below.
    await asyncio.gather(fn(result_1), fn(result_2))

    # Neither should have run in the loop's thread, and they must differ.
    loop_thread = threading.current_thread()
    assert result_1["thread"] != loop_thread
    assert result_2["thread"] != loop_thread
    assert result_1["thread"] != result_2["thread"]


@pytest.mark.asyncio
async def test_sync_to_async_contextvars():
    """
    A function wrapped with `sync_to_async` must see the context variables
    set by its caller, and any values it sets must be copied back into the
    caller's context once the awaited call completes.
    """

    def check_and_update():
        # Block for a second before inspecting the variable.
        time.sleep(1)
        assert foo.get() == "bar"
        foo.set("baz")
        return 42

    foo.set("bar")
    returned = await sync_to_async(check_and_update)()
    assert returned == 42

    # The value written inside the wrapped function is now visible here.
    assert foo.get() == "baz"


@pytest.mark.asyncio
async def test_sync_to_async_contextvars_with_custom_context():
    """
    When `sync_to_async` is handed an explicit `context`, the wrapped function
    reads and writes context variables in that context only: it sees the
    values captured there, and its writes land there without touching the
    caller's own context.
    """

    def check_and_update():
        time.sleep(1)
        assert foo.get() == "bar"
        foo.set("baz")
        return 42

    foo.set("bar")
    isolated = contextvars.copy_context()

    returned = await sync_to_async(check_and_update, context=isolated)()
    assert returned == 42

    # The caller's context still holds the value it set ...
    assert foo.get() == "bar"
    # ... while the explicitly supplied context absorbed the write.
    assert isolated.get(foo) == "baz"


@pytest.mark.asyncio
@pytest.mark.skipif(sys.version_info < (3, 11), reason="requires python3.11")
async def test_sync_to_async_contextvars_with_custom_context_and_parallel_tasks():
    """
    Using a custom context with `sync_to_async` and asyncio tasks isolates
    contextvars changes, leaving the original context unchanged and reflecting
    all modifications in the custom context.

    Four units of work run under one shared `context` object: two sync
    functions via `sync_to_async(..., context=context)` and two coroutines via
    `asyncio.create_task(..., context=context)`. Each appends "1" to `foo`, so
    the custom context must end up holding "1111". Python 3.11+ is required
    because `asyncio.create_task` only accepts `context=` from 3.11 onward.
    """
    # Start from a known empty value in the caller's context.
    foo.set("")

    def sync_function():
        foo.set(foo.get() + "1")
        return 1

    async def async_function():
        foo.set(foo.get() + "1")
        return 1

    # Snapshot of the caller's context; all four calls below are given this
    # same object, so their writes accumulate in it rather than in ours.
    context = contextvars.copy_context()

    # NOTE(review): the sync functions presumably run off the event-loop thread
    # while the tasks run on it, all inside the same `context`. Context.run
    # raises RuntimeError if one context is entered from more than one OS
    # thread at a time, so this appears to rely on those runs never
    # overlapping — confirm this cannot flake.
    await asyncio.gather(
        sync_to_async(sync_function, context=context)(),
        sync_to_async(sync_function, context=context)(),
        asyncio.create_task(async_function(), context=context),
        asyncio.create_task(async_function(), context=context),
    )

    # Current context remains unchanged
    assert foo.get() == ""

    # Custom context reflects the changes made within all the gathered tasks.
    assert context.get(foo) == "1111"


def test_async_to_sync_contextvars():
    """
    Tests to make sure that contextvars from the calling context are
    present in the called context, and that any changes in the called context
    are then propagated back to the calling context.

    This is a plain (non-async) test, so it sets `foo` directly in the calling
    thread's current context. The original value is restored afterwards so
    the "baz" written here does not leak into later code that runs in the
    same context.
    """

    # Define async function
    async def async_function():
        await asyncio.sleep(1)
        assert foo.get() == "bar"
        foo.set("baz")
        return 42

    token = foo.set("bar")
    try:
        # Wrap it
        sync_function = async_to_sync(async_function)
        assert sync_function() == 42
        # The coroutine's write was propagated back into this context.
        assert foo.get() == "baz"
    finally:
        # Reset to the value `foo` had before our set(), even if an assertion
        # above failed. The token belongs to this same context, so reset()
        # is valid here.
        foo.reset(token)