1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173
|
from __future__ import annotations
import asyncio
from unittest import mock
import pytest
from dask.optimization import SubgraphCallable
from distributed.core import ConnectionPool
from distributed.utils_comm import (
WrappedKey,
gather_from_workers,
pack_data,
retry,
subs_multiple,
unpack_remotedata,
)
from distributed.utils_test import BrokenComm, gen_cluster
def test_pack_data():
    """pack_data swaps known keys for their values inside tuples, dicts and lists.

    Names that are absent from the lookup (here ``"y"``) are left untouched.
    """
    lookup = {"x": 1}

    packed_tuple = pack_data(("x", "y"), lookup)
    assert packed_tuple == (1, "y")

    packed_dict = pack_data({"a": "x", "b": "y"}, lookup)
    assert packed_dict == {"a": 1, "b": "y"}

    packed_nested = pack_data({"a": ["x"], "b": "y"}, lookup)
    assert packed_nested == {"a": [1], "b": "y"}
def test_subs_multiple():
    """subs_multiple substitutes keys from ``data`` throughout a task or a graph.

    Unknown names (``"z"``) are kept verbatim, nested lists are traversed, and
    tuple-shaped keys such as ``("y", 0)`` are supported.
    """
    scalars = {"x": 1, "y": 2}
    flat_task = (sum, [0, "x", "y", "z"])
    nested_task = (sum, [0, ["x", "y", "z"]])

    assert subs_multiple(flat_task, scalars) == (sum, [0, 1, 2, "z"])
    assert subs_multiple(nested_task, scalars) == (sum, [0, [1, 2, "z"]])
    assert subs_multiple({"a": (sum, ["x", "y"])}, scalars) == {"a": (sum, [1, 2])}

    # Keys may themselves be tuples
    tuple_keyed = {"x": 1, ("y", 0): 2}
    graph = {"a": (sum, ["x", ("y", 0)])}
    assert subs_multiple(graph, tuple_keyed) == {"a": (sum, [1, 2])}
@gen_cluster(client=True)
async def test_gather_from_workers_permissive(c, s, a, b):
    """Keys that cannot be found come back in ``missing`` instead of failing.

    Only "x" is scattered (to worker ``a``). "y" is claimed to live on ``b``
    but was never stored there, so the gather reports it as missing.
    """
    rpc = await ConnectionPool()
    try:
        await c.scatter({"x": 1}, workers=a.address)

        data, missing, _bad_workers = await gather_from_workers(
            {"x": [a.address], "y": [b.address]}, rpc=rpc
        )

        assert data == {"x": 1}
        assert list(missing) == ["y"]
    finally:
        # The pool may hold comms to the workers; close it so nothing it
        # opened outlives the test, even if an assertion above fails.
        await rpc.close()
class BrokenConnectionPool(ConnectionPool):
    """ConnectionPool whose every ``connect`` call returns a ``BrokenComm``.

    Used to exercise code paths that must cope with comms that fail on use.
    """

    async def connect(self, *_args, **_kwargs):
        # The requested address and options are ignored on purpose.
        return BrokenComm()
@gen_cluster(client=True)
async def test_gather_from_workers_permissive_flaky(c, s, a, b):
    """A worker whose comm breaks is reported as bad, and its keys as missing."""
    await c.scatter({"x": 1}, workers=a.address)
    rpc = await BrokenConnectionPool()
    try:
        _, missing, bad_workers = await gather_from_workers(
            {"x": [a.address]}, rpc=rpc
        )
        assert missing == {"x": [a.address]}
        assert bad_workers == [a.address]
    finally:
        # Release the pool's resources even when the assertions fail.
        await rpc.close()
def test_retry_no_exception(cleanup):
    """A coroutine that succeeds is awaited exactly once; its value passes through."""
    calls = []
    sentinel = object()

    async def succeed():
        calls.append(None)
        return sentinel

    async def main():
        return await retry(succeed, count=0, delay_min=-1, delay_max=-1)

    result = asyncio.run(main())
    assert result is sentinel
    assert len(calls) == 1
def test_retry0_raises_immediately(cleanup):
    """With ``count=0``, retry makes a single attempt and re-raises its error."""
    n_calls = 0

    async def coro():
        nonlocal n_calls
        n_calls += 1
        raise RuntimeError(f"RT_ERROR {n_calls}")

    async def f():
        # count=0 means no retries: only the initial call is made.
        return await retry(coro, count=0, delay_min=-1, delay_max=-1)

    with pytest.raises(RuntimeError, match="RT_ERROR 1"):
        asyncio.run(f())
    assert n_calls == 1
def test_retry_does_retry_and_sleep(cleanup):
    """retry re-invokes on listed exceptions and sleeps between attempts.

    ``asyncio.sleep`` is patched so the requested delays are recorded rather
    than actually waited for; jitter is disabled so they are deterministic.
    """
    attempts = 0

    class MyEx(Exception):
        pass

    async def always_fails():
        nonlocal attempts
        attempts += 1
        raise MyEx(f"RT_ERROR {attempts}")

    recorded_delays = []

    async def fake_sleep(amount):
        recorded_delays.append(amount)

    async def run_with_retries():
        return await retry(
            always_fails,
            retry_on_exceptions=(MyEx,),
            count=5,
            delay_min=1.0,
            delay_max=6.0,
            jitter_fraction=0.0,
        )

    with mock.patch("asyncio.sleep", fake_sleep), pytest.raises(
        MyEx, match="RT_ERROR 6"
    ):
        asyncio.run(run_with_retries())

    # One initial call plus five retries; the delays grow, then cap at delay_max.
    assert attempts == 6
    assert recorded_delays == [0.0, 1.0, 3.0, 6.0, 6.0]
def test_unpack_remotedata():
    def assert_eq(keys1: set[WrappedKey], keys2: set[WrappedKey]) -> None:
        """Assert two sets of WrappedKey objects wrap the same keys.

        Elements are compared via their ``.key`` attribute, so the expected
        set can be built from freshly constructed WrappedKey instances.
        """
        assert len(keys1) == len(keys2)
        # Type-check every element of *both* sets (the union), not only the
        # elements the two sets happen to share.
        assert all(isinstance(k, WrappedKey) for k in keys1 | keys2)
        assert sorted(k.key for k in keys1) == sorted(k.key for k in keys2)

    assert unpack_remotedata(1) == (1, set())
    assert unpack_remotedata(()) == ((), set())

    res, keys = unpack_remotedata(WrappedKey("mykey"))
    assert res == "mykey"
    assert_eq(keys, {WrappedKey("mykey")})

    # Check unpack of SC that contains a wrapped key
    sc = SubgraphCallable({"key": (WrappedKey("data"),)}, outkey="key", inkeys=["arg1"])
    dsk = (sc, "arg1")
    res, keys = unpack_remotedata(dsk)
    assert res[0] != sc  # Notice, the first item (the SC) has been changed
    assert res[1:] == ("arg1", "data")
    assert_eq(keys, {WrappedKey("data")})

    # Check unpack of SC when it takes a wrapped key as argument
    sc = SubgraphCallable({"key": ("arg1",)}, outkey="key", inkeys=[WrappedKey("arg1")])
    dsk = (sc, "arg1")
    res, keys = unpack_remotedata(dsk)
    assert res == (sc, "arg1")  # Notice, the first item (the SC) has NOT been changed
    assert_eq(keys, set())
|