File: test_distributed.py

package info (click to toggle)
python-maggma 0.70.0-7
  • links: PTS, VCS
  • area: main
  • in suites: forky
  • size: 1,416 kB
  • sloc: python: 10,150; makefile: 12
file content (167 lines) | stat: -rw-r--r-- 4,015 bytes parent folder | download | duplicates (3)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
import asyncio
import json
import socket as pysocket
import threading

import pytest
import zmq.asyncio as zmq
from zmq import REP, REQ

from maggma.cli.distributed import find_port, manager, worker
from maggma.core import Builder

# TODO: add coverage for timeout errors?

# Local host name; the worker's handshake message is expected to be "READY_<hostname>".
HOSTNAME: str = pysocket.gethostname()


class DummyBuilderWithNoPrechunk(Builder):
    """No-op ``Builder`` used as a fixture for the distributed manager/worker tests.

    It has no source or target stores, always returns the integers 0-9 from
    ``get_items``, and does nothing when processing items or updating targets.
    It does not define its own ``prechunk``.

    Args:
        dummy_prechunk: Stored on the instance as-is; not read by the methods here.
        val: Arbitrary integer stored on the instance (defaults to -1).
        **kwargs: Any extra keyword arguments, kept in ``self.kwargs``.
    """

    def __init__(self, dummy_prechunk: bool, val: int = -1, **kwargs):
        self.val = val
        self.kwargs = kwargs
        self.dummy_prechunk = dummy_prechunk
        # Set to True by connect().
        self.connected = False
        super().__init__(sources=[], targets=[])

    def connect(self):
        """Mark the builder as connected; there are no stores to open."""
        self.connected = True

    def get_items(self):
        """Return the fixed work list ``[0, 1, ..., 9]``."""
        return [*range(10)]

    def process_items(self, items):
        """Ignore *items*; processing is a no-op."""

    def update_targets(self, items):
        """Ignore *items*; there are no targets to update."""


class DummyBuilder(DummyBuilderWithNoPrechunk):
    """Dummy builder that can split its work into ``num_chunks`` trivial pieces."""

    def prechunk(self, num_chunks):
        """Return one ``{"val": index}`` dict for each index in ``range(num_chunks)``."""
        return [{"val": index} for index in range(num_chunks)]


class DummyBuilderError(DummyBuilderWithNoPrechunk):
    """Dummy builder that prechunks normally but fails when fetching or processing items.

    Both ``get_items`` and ``process_items`` raise ``ValueError("Dummy error")``.
    """

    def prechunk(self, num_chunks):
        """Return one ``{"val": n}`` dict for each n in ``range(num_chunks)``."""
        return [{"val": n} for n in range(num_chunks)]

    def process_items(self, items):
        """Always fail, regardless of *items*."""
        raise ValueError("Dummy error")

    def get_items(self):
        """Always fail instead of producing items."""
        raise ValueError("Dummy error")


# Loopback endpoint and fixed port used by the manager/worker tests in this module.
SERVER_URL, SERVER_PORT = "tcp://127.0.0.1", 1234


@pytest.mark.xfail(raises=ValueError)
def test_wrong_worker_input(log_to_stdout):
    """Running the manager with ``num_workers=0`` is expected to fail with ValueError."""
    builders = [DummyBuilder(dummy_prechunk=False)]
    manager(SERVER_URL, SERVER_PORT, builders, num_chunks=2, num_workers=0)


def test_manager_and_worker(log_to_stdout):
    """End-to-end run: one manager thread and three worker threads on the same endpoint.

    The manager gets a single DummyBuilder plus the positional arguments 5 and 3.
    Each worker is started with the arguments ``1, True``. The test waits for every
    thread to finish.
    """
    worker_count = 3
    mgr = threading.Thread(
        target=manager,
        args=(SERVER_URL, SERVER_PORT, [DummyBuilder(dummy_prechunk=False)], 5, 3),
    )
    mgr.start()

    pool = []
    for _ in range(worker_count):
        pool.append(threading.Thread(target=worker, args=(SERVER_URL, SERVER_PORT, 1, True)))

    for thread in pool:
        thread.start()
    for thread in pool:
        thread.join()

    mgr.join()


@pytest.mark.asyncio()
async def test_manager_worker_error(log_to_stdout):
    """Play a worker that immediately reports an error, then wait for the manager to exit.

    The manager is started in a background thread with 10 chunks and 1 worker.
    This test connects a raw REQ socket to it and sends ``ERROR_testerror``.
    """
    manager_thread = threading.Thread(
        target=manager,
        args=(SERVER_URL, SERVER_PORT, [DummyBuilder(dummy_prechunk=False)], 10, 1),
    )
    manager_thread.start()

    context = zmq.Context()
    socket = context.socket(REQ)
    try:
        socket.connect(f"{SERVER_URL}:{SERVER_PORT}")

        await socket.send(b"ERROR_testerror")
        await asyncio.sleep(1)

        manager_thread.join()
    finally:
        # Previously the socket/context were never released and leaked past the test.
        # linger=0 drops any unsent messages so term() cannot block.
        socket.close(linger=0)
        context.term()


@pytest.mark.asyncio()
async def test_worker_error():
    """A worker handed a failing builder reports ``ERROR_<message>`` back to the manager.

    This test plays the manager. It binds a REP socket and waits for the worker's
    ``READY_<hostname>`` handshake. It then replies with a JSON description of
    ``DummyBuilderError`` and expects the worker to answer ``ERROR_Dummy error``.
    """
    context = zmq.Context()
    socket = context.socket(REP)
    try:
        socket.bind(f"{SERVER_URL}:{SERVER_PORT}")

        worker_task = threading.Thread(target=worker, args=(SERVER_URL, SERVER_PORT, 1, True))
        worker_task.start()

        message = await socket.recv()
        assert message == f"READY_{HOSTNAME}".encode()

        # Work item naming the builder by module/class path; presumably the worker
        # deserializes it from these "@module"/"@class" keys.
        dummy_work = {
            "@module": "tests.cli.test_distributed",
            "@class": "DummyBuilderError",
            "@version": None,
            "dummy_prechunk": False,
            "val": 0,
        }

        await socket.send(json.dumps(dummy_work).encode("utf-8"))
        await asyncio.sleep(1)
        message = await socket.recv()
        assert message.decode("utf-8") == "ERROR_Dummy error"

        worker_task.join()
    finally:
        # Release the bound socket explicitly rather than relying on garbage
        # collection. The next test binds the same fixed port.
        socket.close(linger=0)
        context.term()


@pytest.mark.asyncio()
async def test_worker_exit():
    """A worker that receives ``EXIT`` after its handshake stops running.

    This test plays the manager. It binds a REP socket, waits for the worker's
    ``READY_<hostname>`` message, replies ``EXIT``, and checks that the worker
    thread is no longer alive one second later.
    """
    context = zmq.Context()
    socket = context.socket(REP)
    try:
        socket.bind(f"{SERVER_URL}:{SERVER_PORT}")

        worker_task = threading.Thread(target=worker, args=(SERVER_URL, SERVER_PORT, 1, True))
        worker_task.start()

        message = await socket.recv()
        assert message == f"READY_{HOSTNAME}".encode()
        await asyncio.sleep(1)
        await socket.send(b"EXIT")
        await asyncio.sleep(1)
        assert not worker_task.is_alive()

        worker_task.join()
    finally:
        # Release the bound socket explicitly so the fixed port is freed
        # deterministically instead of whenever the objects are collected.
        socket.close(linger=0)
        context.term()


@pytest.mark.xfail()
def test_no_prechunk(caplog):
    """Running the manager with a builder lacking its own ``prechunk`` is expected to fail.

    The ``caplog`` fixture is requested but not used directly in the body.
    """
    no_prechunk_builder = DummyBuilderWithNoPrechunk(dummy_prechunk=False)
    manager(SERVER_URL, SERVER_PORT, [no_prechunk_builder], 10, 1)


def test_find_port():
    """``find_port`` returns a positive port number."""
    port = find_port()
    assert port > 0