# Thread-safety tests for `Promise` task scheduling and `DataLoader` batching.
from promise import Promise
from promise.dataloader import DataLoader
import threading
def test_promise_thread_safety():
    """
    Check that queued promise tasks run on the thread that scheduled them.

    Promise tasks should never be executed in a different thread from the one
    they are scheduled from, unless the ThreadPoolExecutor is used.

    Thread 1 enqueues a `then` callback and blocks while that callback is
    still pending. Thread 2 then resolves a promise of its own. The callback
    enqueued on thread 1 records whether it ended up running on thread 1; the
    test fails if it did not.
    """
    tasks_enqueued = threading.Event()   # set by thread 1 once its task is queued
    thread_2_drained = threading.Event()  # set by thread 2 after resolving its promise
    outcome = {'is_same_thread': True}

    def noop(value):
        return None

    def task_1():
        scheduling_thread = threading.current_thread().name

        def record_thread(value):
            # Compare the thread that scheduled this callback with the one running it.
            outcome['is_same_thread'] = (
                scheduling_thread == threading.current_thread().name
            )

        def enqueue_and_block(value):
            # Enqueue a task to run later. This relies on the fact that `then`
            # does not execute the function synchronously when called from
            # within another `then` callback function.
            pending = Promise.resolve(None).then(record_thread)
            assert pending.is_pending
            tasks_enqueued.set()  # Unblock main thread
            thread_2_drained.wait()  # Wait for thread 2

        Promise.resolve(None).then(enqueue_and_block)

    def task_2():
        own_promise = Promise.resolve(None).then(noop)
        own_promise.get()  # Drain task queue
        thread_2_drained.set()  # Unblock thread 1

    worker_1 = threading.Thread(target=task_1)
    worker_1.start()

    # Thread 2 must not start until thread 1 has enqueued its promise task.
    tasks_enqueued.wait()

    worker_2 = threading.Thread(target=task_2)
    worker_2.start()

    worker_1.join()
    worker_2.join()

    assert outcome['is_same_thread']
def test_dataloader_thread_safety():
    """
    Check that a shared DataLoader batches `load` calls per thread.

    Dataloader should only batch `load` calls that happened on the same
    thread. Both threads call `load` on the same loader before either result
    is fetched. The batch function answers every key with the name of the
    thread it runs on, so each thread can check that its own key was batched
    on its own thread.
    """
    def load_many(keys):
        # Answer each key with the name of the thread executing this batch.
        batching_thread = threading.current_thread().name
        return Promise.resolve([batching_thread for _ in keys])

    thread_name_loader = DataLoader(load_many)

    thread_1_loaded = threading.Event()   # thread 1 has called `load`
    thread_2_loaded = threading.Event()   # thread 2 has called `load`
    thread_1_checked = threading.Event()  # thread 1 has fetched and checked its result

    outcome = {
        'is_same_thread_1': True,
        'is_same_thread_2': True,
    }

    def task_1():
        @Promise.safe
        def load_and_check():
            pending = thread_name_loader.load(1)
            thread_1_loaded.set()
            thread_2_loaded.wait()  # Wait for thread 2 to call `load`
            batched_on = pending.get()
            running_on = threading.current_thread().name
            outcome['is_same_thread_1'] = (batched_on == running_on)
            thread_1_checked.set()  # Unblock thread 2

        load_and_check().get()

    def task_2():
        @Promise.safe
        def load_and_check():
            pending = thread_name_loader.load(2)
            thread_2_loaded.set()
            thread_1_checked.wait()  # Wait for thread 1 to run `dispatch_queue_batch`
            batched_on = pending.get()
            running_on = threading.current_thread().name
            outcome['is_same_thread_2'] = (batched_on == running_on)

        load_and_check().get()

    worker_1 = threading.Thread(target=task_1)
    worker_1.start()

    # Thread 2 must not start until thread 1 has called `load`.
    thread_1_loaded.wait()

    worker_2 = threading.Thread(target=task_2)
    worker_2.start()

    worker_1.join()
    worker_2.join()

    assert outcome['is_same_thread_1']
    assert outcome['is_same_thread_2']
|