File: test_thread_safety.py

package info (click to toggle)
python-promise 2.3.0-4
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 400 kB
  • sloc: python: 2,681; sh: 13; makefile: 4
file content (115 lines) | stat: -rw-r--r-- 3,432 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
from promise import Promise
from promise.dataloader import DataLoader
import threading



def test_promise_thread_safety():
    """
    Promise tasks should never be executed in a different thread from the one
    they are scheduled from, unless the ThreadPoolExecutor is used.

    Here we assert that the pending promise task scheduled on thread 1 is
    executed on thread 1, and not on thread 2 while thread 2 resolves its own
    promise tasks.

    Every cross-thread wait is bounded by ``timeout``. Exceptions raised on the
    worker threads are re-raised on the main thread, so a regression fails the
    test instead of hanging it or being silently swallowed.
    """
    timeout = 10  # seconds

    event_1 = threading.Event()
    event_2 = threading.Event()

    # 'scheduled_on' / 'executed_on': names of the thread that enqueued
    # `then_2` and of the thread that actually ran it. Nothing is pre-filled,
    # so the test cannot pass if `then_2` never runs.
    observed = {}
    errors = []

    def spawn(task):
        """Start `task` on a daemon thread, recording any exception it raises."""
        def run():
            try:
                task()
            except Exception as exc:  # re-raised on the main thread below
                errors.append(exc)

        thread = threading.Thread(target=run)
        thread.daemon = True  # a wedged thread must not block interpreter exit
        thread.start()
        return thread

    def task_1():
        observed['scheduled_on'] = threading.current_thread().name

        def then_1(value):
            try:
                # Enqueue a task to run later. This relies on the fact that
                # `then` does not execute the function synchronously when
                # called from within another `then` callback function.
                promise = Promise.resolve(None).then(then_2)
                assert promise.is_pending
            finally:
                event_1.set()  # Unblock main thread, even if the check failed
            assert event_2.wait(timeout), 'thread 2 never drained its queue'

        def then_2(value):
            observed['executed_on'] = threading.current_thread().name

        # An exception inside then_1 rejects this promise instead of
        # propagating; get() re-raises it so `spawn` can record it.
        Promise.resolve(None).then(then_1).get()

    def task_2():
        promise = Promise.resolve(None).then(lambda v: None)
        promise.get()  # Drain task queue
        event_2.set()  # Unblock thread 1

    thread_1 = spawn(task_1)
    thread_1_ready = event_1.wait(timeout)  # Wait for thread 1 to enqueue its task

    thread_2 = spawn(task_2)

    for thread in (thread_1, thread_2):
        thread.join(timeout)

    # Report the root-cause exception first; the checks below would only show
    # its symptoms.
    if errors:
        raise errors[0]
    assert thread_1_ready, 'thread 1 never enqueued its promise task'
    assert not thread_1.is_alive() and not thread_2.is_alive(), 'threads deadlocked'
    assert 'executed_on' in observed, 'the task enqueued on thread 1 never ran'
    assert observed['executed_on'] == observed['scheduled_on'], (
        'task scheduled on %s was executed on %s'
        % (observed['scheduled_on'], observed['executed_on']))


def test_dataloader_thread_safety():
    """
    Dataloader should only batch `load` calls that happened on the same thread.

    Here we assert that `load` calls on thread 2 are not batched on thread 1 as
    thread 1 batches its own `load` calls: every load must resolve to the name
    of the thread that issued it.

    Every cross-thread wait is bounded by ``timeout``. Exceptions raised on the
    worker threads are re-raised on the main thread, so a regression fails the
    test instead of hanging it or being silently swallowed.
    """
    timeout = 10  # seconds

    def load_many(keys):
        # Resolve every key to the name of the thread that runs this batch.
        thread_name = threading.current_thread().name
        return Promise.resolve([thread_name for _ in keys])

    thread_name_loader = DataLoader(load_many)

    event_1 = threading.Event()
    event_2 = threading.Event()
    event_3 = threading.Event()

    # Maps each task to (thread that called `load`, thread that batched it).
    # Nothing is pre-filled, so a task that dies early cannot pass vacuously.
    results = {}
    errors = []

    def spawn(task):
        """Start `task` on a daemon thread, recording any exception it raises."""
        def run():
            try:
                task()
            except Exception as exc:  # re-raised on the main thread below
                errors.append(exc)

        thread = threading.Thread(target=run)
        thread.daemon = True  # a wedged thread must not block interpreter exit
        thread.start()
        return thread

    def task_1():
        @Promise.safe
        def do():
            promise = thread_name_loader.load(1)
            event_1.set()
            # Wait for thread 2 to call `load` before resolving our own batch.
            assert event_2.wait(timeout), 'thread 2 never called load'
            results['thread_1'] = (threading.current_thread().name, promise.get())
            event_3.set()  # Unblock thread 2

        # A failure inside `do` rejects its promise; get() re-raises it here.
        do().get()

    def task_2():
        @Promise.safe
        def do():
            promise = thread_name_loader.load(2)
            event_2.set()
            # Wait for thread 1 to run `dispatch_queue_batch`
            assert event_3.wait(timeout), 'thread 1 never resolved its load'
            results['thread_2'] = (threading.current_thread().name, promise.get())

        do().get()

    thread_1 = spawn(task_1)
    thread_1_ready = event_1.wait(timeout)  # Wait for thread 1 to call `load`

    thread_2 = spawn(task_2)

    for thread in (thread_1, thread_2):
        thread.join(timeout)

    # Report the root-cause exception first; the checks below would only show
    # its symptoms.
    if errors:
        raise errors[0]
    assert thread_1_ready, 'thread 1 never called load'
    assert not thread_1.is_alive() and not thread_2.is_alive(), 'threads deadlocked'
    for label in ('thread_1', 'thread_2'):
        assert label in results, '%s never resolved its load' % label
        issued_on, batched_on = results[label]
        assert batched_on == issued_on, (
            'load issued on %s was batched on %s' % (issued_on, batched_on))