File: test_utils_comm.py

Package: dask.distributed 2022.12.1+ds.1-3 (Debian bookworm)

from __future__ import annotations

import asyncio
from unittest import mock

import pytest

from dask.optimization import SubgraphCallable

from distributed.core import ConnectionPool
from distributed.utils_comm import (
    WrappedKey,
    gather_from_workers,
    pack_data,
    retry,
    subs_multiple,
    unpack_remotedata,
)
from distributed.utils_test import BrokenComm, gen_cluster


def test_pack_data():
    data = {"x": 1}
    assert pack_data(("x", "y"), data) == (1, "y")
    assert pack_data({"a": "x", "b": "y"}, data) == {"a": 1, "b": "y"}
    assert pack_data({"a": ["x"], "b": "y"}, data) == {"a": [1], "b": "y"}
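

# Hedged extra example (added here for illustration, not part of the original
# suite): pack_data appears to recurse through nested lists/tuples/dicts, so a
# small sketch of that behaviour is checked below.  The test name is new, and
# the assumption is that keys absent from ``data`` pass through unchanged.
def test_pack_data_nested():
    data = {"x": 1}
    assert pack_data(["x", ("x", "z")], data) == [1, (1, "z")]
    assert pack_data({"a": {"b": "x"}}, data) == {"a": {"b": 1}}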


def test_subs_multiple():
    data = {"x": 1, "y": 2}
    assert subs_multiple((sum, [0, "x", "y", "z"]), data) == (sum, [0, 1, 2, "z"])
    assert subs_multiple((sum, [0, ["x", "y", "z"]]), data) == (sum, [0, [1, 2, "z"]])

    dsk = {"a": (sum, ["x", "y"])}
    assert subs_multiple(dsk, data) == {"a": (sum, [1, 2])}

    # Tuple key
    data = {"x": 1, ("y", 0): 2}
    dsk = {"a": (sum, ["x", ("y", 0)])}
    assert subs_multiple(dsk, data) == {"a": (sum, [1, 2])}
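

# Hedged extra example (added for illustration): subs_multiple seems to treat
# a tuple whose first element is callable as a task to recurse into, while any
# other tuple is looked up as a key.  Both the test name and the exact
# traversal rule are assumptions based on the cases above.
def test_subs_multiple_passthrough():
    data = {"x": 1}
    # Keys absent from ``data`` are left untouched
    assert subs_multiple((sum, ["x", "w"]), data) == (sum, [1, "w"])
    # A non-task tuple is treated as a potential key, not recursed into
    assert subs_multiple(("x", "y"), data) == ("x", "y")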


@gen_cluster(client=True)
async def test_gather_from_workers_permissive(c, s, a, b):
    rpc = await ConnectionPool()
    x = await c.scatter({"x": 1}, workers=a.address)

    data, missing, bad_workers = await gather_from_workers(
        {"x": [a.address], "y": [b.address]}, rpc=rpc
    )

    assert data == {"x": 1}
    assert list(missing) == ["y"]
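

# Hedged companion test (new, for illustration): when every requested key is
# reachable, gather_from_workers is expected to return all of the data with
# empty ``missing`` and ``bad_workers``; only their falsiness is asserted, as
# the exact empty-container types are an assumption here.
@gen_cluster(client=True)
async def test_gather_from_workers_all_present(c, s, a, b):
    rpc = await ConnectionPool()
    # Keep the scattered futures referenced so the data is not released
    x = await c.scatter({"x": 1}, workers=a.address)
    y = await c.scatter({"y": 2}, workers=b.address)

    data, missing, bad_workers = await gather_from_workers(
        {"x": [a.address], "y": [b.address]}, rpc=rpc
    )

    assert data == {"x": 1, "y": 2}
    assert not missing
    assert not bad_workers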


class BrokenConnectionPool(ConnectionPool):
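    # Test helper: every connection handed out by this pool is a BrokenComm,
    # so any RPC attempt through it fails as if the worker were unreachable.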
    async def connect(self, *args, **kwargs):
        return BrokenComm()


@gen_cluster(client=True)
async def test_gather_from_workers_permissive_flaky(c, s, a, b):
    x = await c.scatter({"x": 1}, workers=a.address)

    rpc = await BrokenConnectionPool()
    data, missing, bad_workers = await gather_from_workers({"x": [a.address]}, rpc=rpc)

    assert missing == {"x": [a.address]}
    assert bad_workers == [a.address]


def test_retry_no_exception(cleanup):
    n_calls = 0
    retval = object()

    async def coro():
        nonlocal n_calls
        n_calls += 1
        return retval

    async def f():
        return await retry(coro, count=0, delay_min=-1, delay_max=-1)

    assert asyncio.run(f()) is retval
    assert n_calls == 1


def test_retry0_raises_immediately(cleanup):
    # test that count=0 (i.e. no retries) raises after a single call

    n_calls = 0

    async def coro():
        nonlocal n_calls
        n_calls += 1
        raise RuntimeError(f"RT_ERROR {n_calls}")

    async def f():
        return await retry(coro, count=0, delay_min=-1, delay_max=-1)

    with pytest.raises(RuntimeError, match="RT_ERROR 1"):
        asyncio.run(f())

    assert n_calls == 1


def test_retry_does_retry_and_sleep(cleanup):
    # test the retry and sleep pattern of `retry`
    n_calls = 0

    class MyEx(Exception):
        pass

    async def coro():
        nonlocal n_calls
        n_calls += 1
        raise MyEx(f"RT_ERROR {n_calls}")

    sleep_calls = []

    async def my_sleep(amount):
        sleep_calls.append(amount)
        return

    async def f():
        return await retry(
            coro,
            retry_on_exceptions=(MyEx,),
            count=5,
            delay_min=1.0,
            delay_max=6.0,
            jitter_fraction=0.0,
        )

    with mock.patch("asyncio.sleep", my_sleep):
        with pytest.raises(MyEx, match="RT_ERROR 6"):
            asyncio.run(f())

    assert n_calls == 6
    assert sleep_calls == [0.0, 1.0, 3.0, 6.0, 6.0]
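

# Hedged extra example (new test, for illustration): a sketch of the transient
# failure case, assuming that once the coroutine stops raising, retry returns
# its value without further calls.  Zero delays keep the test fast without
# mocking asyncio.sleep.
def test_retry_succeeds_after_failure(cleanup):
    n_calls = 0
    retval = object()

    class MyEx(Exception):
        pass

    async def coro():
        nonlocal n_calls
        n_calls += 1
        if n_calls == 1:
            raise MyEx("transient")
        return retval

    async def f():
        return await retry(
            coro,
            retry_on_exceptions=(MyEx,),
            count=3,
            delay_min=0,
            delay_max=0,
            jitter_fraction=0.0,
        )

    assert asyncio.run(f()) is retval
    assert n_calls == 2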


def test_unpack_remotedata():
    def assert_eq(keys1: set[WrappedKey], keys2: set[WrappedKey]) -> None:
        assert len(keys1) == len(keys2)
        if not keys1:
            return  # both empty; nothing further to compare
        assert all(isinstance(k, WrappedKey) for k in keys1 & keys2)
        assert sorted([k.key for k in keys1]) == sorted([k.key for k in keys2])

    assert unpack_remotedata(1) == (1, set())
    assert unpack_remotedata(()) == ((), set())

    res, keys = unpack_remotedata(WrappedKey("mykey"))
    assert res == "mykey"
    assert_eq(keys, {WrappedKey("mykey")})

    # Check unpack of SC that contains a wrapped key
    sc = SubgraphCallable({"key": (WrappedKey("data"),)}, outkey="key", inkeys=["arg1"])
    dsk = (sc, "arg1")
    res, keys = unpack_remotedata(dsk)
    assert res[0] != sc  # Note that the first item (the SC) has been changed
    assert res[1:] == ("arg1", "data")
    assert_eq(keys, {WrappedKey("data")})

    # Check unpack of SC when it takes a wrapped key as argument
    sc = SubgraphCallable({"key": ("arg1",)}, outkey="key", inkeys=[WrappedKey("arg1")])
    dsk = (sc, "arg1")
    res, keys = unpack_remotedata(dsk)
    assert res == (sc, "arg1")  # Note that the first item (the SC) has NOT been changed
    assert_eq(keys, set())
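

# Hedged extra example (new test, for illustration): WrappedKeys nested inside
# plain tuples are assumed to be unwrapped recursively, mirroring the scalar
# and SubgraphCallable cases above.
def test_unpack_remotedata_nested():
    res, keys = unpack_remotedata((WrappedKey("a"), (WrappedKey("b"), 1)))
    assert res == ("a", ("b", 1))
    assert {k.key for k in keys} == {"a", "b"}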