File: test_comm_buffer.py

package info (click to toggle)
dask.distributed 2022.12.1%2Bds.1-3
  • links: PTS, VCS
  • area: main
  • in suites: bookworm
  • size: 10,164 kB
  • sloc: python: 81,938; javascript: 1,549; makefile: 228; sh: 100
file content (136 lines) | stat: -rw-r--r-- 3,334 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
from __future__ import annotations

import asyncio
import math
from collections import defaultdict

import pytest

from distributed.shuffle._comms import CommShardsBuffer
from distributed.utils_test import gen_test


@gen_test()
async def test_basic(tmp_path):
    """Two identical writes followed by a flush reach ``send`` in full.

    The shards handed to ``send`` are collected per address and, once
    joined, must equal the concatenation of everything that was written.
    """
    received = defaultdict(list)

    async def record(address, shards):
        # Collect whatever the buffer passes on, grouped by destination.
        received[address].extend(shards)

    buffer = CommShardsBuffer(send=record)
    for _ in range(2):
        await buffer.write({"x": [b"0" * 1000], "y": [b"1" * 500]})

    await buffer.flush()

    expected_by_address = (("x", b"0" * 2000), ("y", b"1" * 1000))
    for address, expected in expected_by_address:
        assert b"".join(received[address]) == expected


@gen_test()
async def test_exceptions(tmp_path):
    """An exception raised by ``send`` is re-raised by a subsequent write.

    The first write is expected to return normally; the test then waits
    until the buffer exposes the failure through ``_exception`` and checks
    that the next ``write`` raises it. ``flush`` and ``close`` are awaited
    afterwards and must not raise.
    """

    async def send(address, shards):
        raise Exception(123)

    mc = CommShardsBuffer(send=send)
    await mc.write({"x": [b"0" * 1000], "y": [b"1" * 500]})

    # The failure is not visible immediately after the first write, so poll
    # until the buffer has recorded it.
    while not mc._exception:
        await asyncio.sleep(0.1)

    with pytest.raises(Exception, match="123"):
        await mc.write({"x": [b"0" * 1000], "y": [b"1" * 500]})

    await mc.flush()

    await mc.close()


@gen_test()
async def test_slow_send(tmpdir):
    """A write issued while a flush is pending raises ``RuntimeError``.

    ``send`` is gated on ``block_send``: the first send goes through and
    signals ``sending_first``; the gate is then closed so the flush cannot
    complete while the extra write is attempted. Afterwards the gate is
    reopened, the flush is awaited, and the test verifies that none of the
    rejected ``b"2"`` data was ever sent to ``"x"``.
    """
    block_send = asyncio.Event()
    block_send.set()
    sending_first = asyncio.Event()
    d = defaultdict(list)

    async def send(address, shards):
        await block_send.wait()
        d[address].extend(shards)
        sending_first.set()

    mc = CommShardsBuffer(send=send, concurrency_limit=1)
    await mc.write({"x": [b"0"], "y": [b"1"]})
    await mc.write({"x": [b"0"], "y": [b"1"]})
    flush_task = asyncio.create_task(mc.flush())
    await sending_first.wait()
    # Stall every later send so the flush is still pending during the write.
    block_send.clear()

    # Only the write belongs in the raises block: anything placed after the
    # raising statement inside the block would never execute.
    with pytest.raises(RuntimeError):
        await mc.write({"x": [b"2"], "y": [b"2"]})

    # Reopen the gate and let the flush finish so no task is left pending.
    block_send.set()
    await flush_task

    # Asserting a bare list only checks that it is non-empty; keep that check
    # explicit and verify every shard individually with all().
    assert d["x"]
    assert all(b"2" not in shard for shard in d["x"])


def gen_bytes(percentage: float) -> bytes:
    """Return a payload of ``b"0"`` bytes sized relative to the buffer limit.

    The length is ``percentage * CommShardsBuffer.memory_limit`` rounded
    down to a whole number of bytes.
    """
    scaled_size = percentage * CommShardsBuffer.memory_limit
    return b"0" * int(math.floor(scaled_size))


@gen_test()
async def test_concurrent_puts():
    """Many concurrent writes are all delivered and leave the buffer empty.

    ``nputs`` write tasks, each carrying one shard for every one of
    ``nshards`` addresses, are run concurrently. After the flush the buffer
    must hold no shards or sizes, every address must have been sent to, and
    address ``0`` must have received the full number of bytes written.
    """
    received = defaultdict(list)

    async def send(address, shards):
        received[address].extend(shards)

    frac = 0.1
    nshards = 10
    nputs = 20
    shard = gen_bytes(frac)
    payload = {address: [shard] for address in range(nshards)}

    async with CommShardsBuffer(send=send) as mc:
        writers = []
        for _ in range(nputs):
            writers.append(asyncio.create_task(mc.write(payload)))

        await asyncio.gather(*writers)
        await mc.flush()

        # Everything has been handed to ``send`` by now.
        assert not mc.shards
        assert not mc.sizes

    # Leaving the context manager must not bring any state back.
    assert not mc.shards
    assert not mc.sizes
    assert len(received) == nshards
    bytes_for_first_address = sum(len(chunk) for chunk in received[0])
    assert bytes_for_first_address == len(shard) * nputs


@gen_test()
async def test_concurrent_puts_error():
    """A failing ``send`` during concurrent writes surfaces on the buffer.

    The fifth call to ``send`` raises ``OSError``. After all write tasks and
    the flush have been awaited, ``raise_on_exception`` must re-raise that
    error, and the buffer must hold no shards or sizes once the context
    manager has exited.
    """
    received = defaultdict(list)

    calls = 0

    async def send(address, shards):
        nonlocal calls
        calls += 1
        # Fail exactly once, on the fifth invocation.
        if calls == 5:
            raise OSError("error during send")
        received[address].extend(shards)

    frac = 0.1
    nshards = 10
    nputs = 20
    payload = {address: [gen_bytes(frac)] for address in range(nshards)}

    async with CommShardsBuffer(send=send) as mc:
        writers = []
        for _ in range(nputs):
            writers.append(asyncio.create_task(mc.write(payload)))

        await asyncio.gather(*writers)
        await mc.flush()
        with pytest.raises(OSError, match="error during send"):
            mc.raise_on_exception()

    assert not mc.shards
    assert not mc.sizes