1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136
|
from __future__ import annotations
import asyncio
import math
from collections import defaultdict
import pytest
from distributed.shuffle._comms import CommShardsBuffer
from distributed.utils_test import gen_test
@gen_test()
async def test_basic(tmp_path):
    """Two identical writes are concatenated per address and delivered on flush."""
    received = defaultdict(list)

    async def send(address, shards):
        received[address].extend(shards)

    buffer = CommShardsBuffer(send=send)
    # Write the same mapping twice; a fresh dict literal is built each time.
    for _ in range(2):
        await buffer.write({"x": [b"0" * 1000], "y": [b"1" * 500]})
    await buffer.flush()

    assert b"".join(received["x"]) == b"0" * 2000
    assert b"".join(received["y"]) == b"1" * 1000
@gen_test()
async def test_exceptions(tmp_path):
    """An error raised by ``send`` is stored and re-raised by the next write.

    The first write succeeds. Its shards are handed to ``send`` in the
    background, and ``send`` always raises. Once the buffer has recorded that
    error, the next ``write`` must raise it. ``flush`` and ``close`` are still
    expected to complete afterwards.
    """

    async def send(address, shards):
        raise Exception(123)

    mc = CommShardsBuffer(send=send)
    try:
        await mc.write({"x": [b"0" * 1000], "y": [b"1" * 500]})

        # The failing send runs asynchronously. Poll the buffer's private
        # ``_exception`` slot until the error has been recorded. gen_test's
        # timeout bounds this loop.
        while not mc._exception:
            await asyncio.sleep(0.01)

        with pytest.raises(Exception, match="123"):
            await mc.write({"x": [b"0" * 1000], "y": [b"1" * 500]})

        await mc.flush()
    finally:
        # Always release the buffer, even if an assertion above failed, so
        # its background tasks do not outlive the test.
        await mc.close()
@gen_test()
async def test_slow_send(tmpdir):
    """A write issued after ``flush()`` has started is rejected and never sent.

    ``concurrency_limit=1`` serializes calls to ``send``. After the first send
    has completed, ``block_send`` is cleared so that any later send stalls
    inside ``send`` until it is set again.
    """
    block_send = asyncio.Event()
    block_send.set()
    sending_first = asyncio.Event()
    d = defaultdict(list)

    async def send(address, shards):
        await block_send.wait()
        d[address].extend(shards)
        sending_first.set()

    mc = CommShardsBuffer(send=send, concurrency_limit=1)
    try:
        await mc.write({"x": [b"0"], "y": [b"1"]})
        await mc.write({"x": [b"0"], "y": [b"1"]})
        flush_task = asyncio.create_task(mc.flush())
        await sending_first.wait()
        block_send.clear()

        # flush() is in progress, so the buffer must refuse new input.
        with pytest.raises(RuntimeError):
            await mc.write({"x": [b"2"], "y": [b"2"]})

        # Release any send still waiting on the event; otherwise flush()
        # could wait forever (until the test timeout).
        block_send.set()
        await flush_task
    finally:
        # Same reason as above: close() flushes, so never leave sends blocked.
        block_send.set()
        await mc.close()

    # A bare list comprehension is truthy whenever it is non-empty, so the
    # check must go through all() to actually verify the rejected payload
    # never reached ``send`` for any address.
    assert all(b"2" not in shard for shards in d.values() for shard in shards)
    # The two writes accepted before flush() were fully delivered.
    assert b"".join(d["x"]) == b"00"
    assert b"".join(d["y"]) == b"11"
def gen_bytes(percentage: float) -> bytes:
num_bytes = int(math.floor(percentage * CommShardsBuffer.memory_limit))
return b"0" * num_bytes
@gen_test()
async def test_concurrent_puts():
    """Concurrent writes of one payload are all delivered to every address.

    ``nputs`` tasks each write one shard of ``gen_bytes(frac)`` to each of
    ``nshards`` addresses. After flushing, each address must have received
    exactly ``nputs`` such shards.
    """
    d = defaultdict(list)

    async def send(address, shards):
        d[address].extend(shards)

    frac = 0.1
    nshards = 10
    nputs = 20
    payload = {x: [gen_bytes(frac)] for x in range(nshards)}
    async with CommShardsBuffer(send=send) as mc:
        futs = [asyncio.create_task(mc.write(payload)) for _ in range(nputs)]
        await asyncio.gather(*futs)
        await mc.flush()

        # Everything buffered has been handed to ``send``.
        assert not mc.shards
        assert not mc.sizes

    # Leaving the context manager must also leave the buffer empty.
    assert not mc.shards
    assert not mc.sizes

    # Check against the configured sizes rather than hard-coded numbers, and
    # verify every address, not just address 0. The length is checked first
    # so the defaultdict lookups below cannot add keys before it is compared.
    assert len(d) == nshards
    expected_bytes = len(gen_bytes(frac)) * nputs
    for address in range(nshards):
        assert sum(map(len, d[address])) == expected_bytes
@gen_test()
async def test_concurrent_puts_error():
    """A failing ``send`` is recorded and re-raised by ``raise_on_exception``.

    The fifth call to ``send`` raises ``OSError``. After the concurrent writes
    and a flush, the stored error must surface. The buffer must hold no
    shards once the context manager exits.
    """
    received = defaultdict(list)
    calls = 0

    async def send(address, shards):
        # Fail on exactly the fifth invocation; every other call delivers.
        nonlocal calls
        calls += 1
        if calls == 5:
            raise OSError("error during send")
        received[address].extend(shards)

    fraction, n_addresses, n_writes = 0.1, 10, 20
    payload = {}
    for address in range(n_addresses):
        payload[address] = [gen_bytes(fraction)]

    async with CommShardsBuffer(send=send) as buffer:
        writes = []
        for _ in range(n_writes):
            writes.append(asyncio.create_task(buffer.write(payload)))
        await asyncio.gather(*writes)
        await buffer.flush()
        with pytest.raises(OSError, match="error during send"):
            buffer.raise_on_exception()

    assert not buffer.shards
    assert not buffer.sizes
|