1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167
|
import asyncio
import json
import socket as pysocket
import threading
import pytest
import zmq.asyncio as zmq
from zmq import REP, REQ
from maggma.cli.distributed import find_port, manager, worker
from maggma.core import Builder
# TODO: Add tests for timeout errors (e.g. a manager or worker that never responds).
# Hostname of the machine running the tests; the worker tests expect the
# READY handshake to be f"READY_{HOSTNAME}".
HOSTNAME: str = pysocket.gethostname()
class DummyBuilderWithNoPrechunk(Builder):
    """Minimal test builder with no stores and no ``prechunk`` override.

    Args:
        dummy_prechunk: Kept as ``self.dummy_prechunk``; not otherwise read by this class.
        val: Kept as ``self.val`` (defaults to -1).
        **kwargs: Any extra keyword arguments, kept as ``self.kwargs``.
    """

    def __init__(self, dummy_prechunk: bool, val: int = -1, **kwargs):
        self.val = val
        self.kwargs = kwargs
        self.connected = False  # flipped to True by connect()
        self.dummy_prechunk = dummy_prechunk
        # Empty source/target lists: this builder reads from and writes to nothing.
        super().__init__(sources=[], targets=[])

    def connect(self):
        """Record that connect() was called; there are no stores to open."""
        self.connected = True

    def get_items(self):
        """Return the ten integers 0..9 as work items."""
        return [*range(10)]

    def process_items(self, items):
        """Ignore ``items``; returns None."""
        return None

    def update_targets(self, items):
        """Ignore ``items``; returns None."""
        return None
class DummyBuilder(DummyBuilderWithNoPrechunk):
    """Dummy builder that adds a ``prechunk`` implementation."""

    def prechunk(self, num_chunks):
        """Return ``num_chunks`` dicts; chunk ``i`` is ``{"val": i}``."""
        return [dict(val=index) for index in range(num_chunks)]
class DummyBuilderError(DummyBuilder):
    """Dummy builder whose item retrieval and processing always fail.

    ``prechunk`` is inherited from ``DummyBuilder`` rather than duplicated, so both
    builders chunk identically. Only ``get_items`` and ``process_items`` are
    overridden, and each raises, so tests can exercise the error-reporting path.
    """

    def get_items(self):
        """Always raise ``ValueError("Dummy error")``."""
        raise ValueError("Dummy error")

    def process_items(self, items):
        """Always raise ``ValueError("Dummy error")``; ``items`` is ignored."""
        raise ValueError("Dummy error")
# Endpoint shared by every test in this module (sockets use f"{SERVER_URL}:{SERVER_PORT}").
# NOTE(review): a fixed port can collide with other processes on the host;
# find_port() could supply a free one if that becomes a problem.
SERVER_URL, SERVER_PORT = "tcp://127.0.0.1", 1234
@pytest.mark.xfail(raises=ValueError)
def test_wrong_worker_input(log_to_stdout):
    """A manager configured with zero workers is marked xfail with ValueError."""
    builders = [DummyBuilder(dummy_prechunk=False)]
    manager(SERVER_URL, SERVER_PORT, builders, num_chunks=2, num_workers=0)
def test_manager_and_worker(log_to_stdout):
    """Run a manager (5 chunks, 3 workers) against three worker threads until all finish."""
    num_workers = 3
    manager_thread = threading.Thread(
        target=manager,
        args=(SERVER_URL, SERVER_PORT, [DummyBuilder(dummy_prechunk=False)], 5, num_workers),
    )
    manager_thread.start()

    workers = [
        threading.Thread(target=worker, args=(SERVER_URL, SERVER_PORT, 1, True))
        for _ in range(num_workers)
    ]
    for thread in workers:
        thread.start()
    for thread in workers:
        thread.join()

    manager_thread.join()
@pytest.mark.asyncio()
async def test_manager_worker_error(log_to_stdout):
    """Send the manager an ``ERROR_`` message from a raw REQ socket and wait for it to exit.

    The manager runs in a background thread (10 chunks, 1 worker). In place of a real
    worker, the test connects a REQ socket and reports an error.
    """
    manager_thread = threading.Thread(
        target=manager,
        args=(SERVER_URL, SERVER_PORT, [DummyBuilder(dummy_prechunk=False)], 10, 1),
    )
    manager_thread.start()

    context = zmq.Context()
    socket = context.socket(REQ)
    try:
        socket.connect(f"{SERVER_URL}:{SERVER_PORT}")
        await socket.send(b"ERROR_testerror")
        await asyncio.sleep(1)

        manager_thread.join()
    finally:
        # The original never closed the socket or context. Release both even when the
        # test fails; linger=0 discards unsent messages so teardown cannot block.
        context.destroy(linger=0)
@pytest.mark.asyncio()
async def test_worker_error():
    """A worker given a builder that raises should reply ``ERROR_<message>``."""
    context = zmq.Context()
    socket = context.socket(REP)
    try:
        socket.bind(f"{SERVER_URL}:{SERVER_PORT}")

        worker_task = threading.Thread(target=worker, args=(SERVER_URL, SERVER_PORT, 1, True))
        worker_task.start()

        # The worker's first message is a READY handshake carrying its hostname.
        message = await socket.recv()
        assert message == f"READY_{HOSTNAME}".encode()

        # JSON builder spec (@module/@class/@version plus constructor args); presumably
        # the worker decodes this into a DummyBuilderError, whose methods raise "Dummy error".
        dummy_work = {
            "@module": "tests.cli.test_distributed",
            "@class": "DummyBuilderError",
            "@version": None,
            "dummy_prechunk": False,
            "val": 0,
        }
        await socket.send(json.dumps(dummy_work).encode("utf-8"))
        await asyncio.sleep(1)
        message = await socket.recv()
        assert message.decode("utf-8") == "ERROR_Dummy error"

        worker_task.join()
    finally:
        # Close the socket and context so the fixed port is released before later
        # tests bind it again.
        context.destroy(linger=0)
@pytest.mark.asyncio()
async def test_worker_exit():
    """A worker sent ``EXIT`` after its READY handshake should stop its thread."""
    context = zmq.Context()
    socket = context.socket(REP)
    try:
        socket.bind(f"{SERVER_URL}:{SERVER_PORT}")

        worker_task = threading.Thread(target=worker, args=(SERVER_URL, SERVER_PORT, 1, True))
        worker_task.start()

        message = await socket.recv()
        assert message == f"READY_{HOSTNAME}".encode()
        await asyncio.sleep(1)
        await socket.send(b"EXIT")

        # Wait on the thread itself instead of sleeping a fixed second and then checking
        # is_alive(), which is flaky when the worker is slow to stop. The timeout keeps
        # a worker that never exits from hanging the test suite.
        worker_task.join(timeout=10)
        assert not worker_task.is_alive()
    finally:
        # Close the socket and context so the fixed port is released for later tests.
        context.destroy(linger=0)
@pytest.mark.xfail()
def test_no_prechunk(caplog):
    """A manager given a builder without a ``prechunk`` override is marked xfail."""
    builder = DummyBuilderWithNoPrechunk(dummy_prechunk=False)
    manager(SERVER_URL, SERVER_PORT, [builder], num_chunks=10, num_workers=1)
def test_find_port():
    """``find_port`` should return an int that is a valid, non-zero TCP port number."""
    port = find_port()
    assert isinstance(port, int)
    # 0 is the "pick any port" wildcard and ports above 65535 do not exist, so both
    # would mean find_port returned something unusable.
    assert 0 < port <= 65535
|