File: test_comm_buffer.py

package info (click to toggle)
dask.distributed 2024.12.1%2Bds-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 12,588 kB
  • sloc: python: 96,954; javascript: 1,549; sh: 390; makefile: 220
file content (175 lines) | stat: -rw-r--r-- 4,385 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
from __future__ import annotations

import asyncio
import math
from collections import defaultdict

import pytest

from dask.utils import parse_bytes

from distributed.shuffle._comms import CommShardsBuffer
from distributed.shuffle._limiter import ResourceLimiter
from distributed.utils_test import gen_test


@gen_test()
async def test_basic(tmp_path):
    """Write two identical batches and verify what ``send`` received.

    Each batch carries 1000 bytes for address ``"x"`` and 500 bytes for
    address ``"y"``; after flushing, the shards collected per address must
    concatenate to the combined payloads.
    """
    received = defaultdict(list)

    async def send(address, shards):
        # Collect every shard handed to the send callback, keyed by address.
        bucket = received[address]
        bucket.extend(shards)

    buffer = CommShardsBuffer(
        send=send,
        max_message_size=parse_bytes("2 MiB"),
        memory_limiter=ResourceLimiter(None),
        concurrency_limit=10,
    )
    for _ in range(2):
        await buffer.write({"x": b"0" * 1000, "y": b"1" * 500})

    await buffer.flush()

    expected = {"x": b"0" * 2000, "y": b"1" * 1000}
    for address, blob in expected.items():
        assert b"".join(received[address]) == blob


@gen_test()
async def test_exceptions(tmp_path):
    """Check that a failure raised by ``send`` surfaces on a later ``write``.

    The ``send`` callback always raises ``Exception(123)``. The first
    ``write`` is expected to return normally; once the buffer exposes the
    failure via ``_exception``, a second ``write`` must raise an exception
    whose message matches ``"123"``. ``flush`` and ``close`` are then
    awaited and are expected not to raise.
    """

    async def send(address, shards):
        raise Exception(123)

    mc = CommShardsBuffer(
        send=send,
        max_message_size=parse_bytes("2 MiB"),
        memory_limiter=ResourceLimiter(None),
        concurrency_limit=10,
    )
    await mc.write({"x": b"0" * 1000, "y": b"1" * 500})

    # The write above did not raise, so poll until the buffer has recorded
    # the failure coming from ``send``.
    while not mc._exception:
        await asyncio.sleep(0.1)

    with pytest.raises(Exception, match="123"):
        await mc.write({"x": b"0" * 1000, "y": b"1" * 500})

    await mc.flush()

    await mc.close()


@gen_test()
async def test_slow_send(tmp_path):
    """Check that writing while a flush is in progress raises ``RuntimeError``.

    ``send`` is gated on ``block_send``: it starts open so the first send
    completes, then it is cleared so any later send waits. After the first
    send has finished, a further ``write`` (or the pending flush) must raise
    ``RuntimeError``, and none of the shards delivered for ``"x"`` may
    contain the rejected ``b"2"`` payload.
    """
    block_send = asyncio.Event()
    block_send.set()
    sending_first = asyncio.Event()
    d = defaultdict(list)

    async def send(address, shards):
        await block_send.wait()
        d[address].extend(shards)
        sending_first.set()
        return {"status": "OK"}

    mc = CommShardsBuffer(
        send=send,
        max_message_size=parse_bytes("2 MiB"),
        concurrency_limit=1,
        memory_limiter=ResourceLimiter(None),
    )
    await mc.write({"x": b"0", "y": b"1"})
    await mc.write({"x": b"0", "y": b"1"})
    flush_task = asyncio.create_task(mc.flush())
    await sending_first.wait()
    # From here on every further call to ``send`` waits on the event.
    block_send.clear()

    with pytest.raises(RuntimeError):
        # If this write raises, the ``await flush_task`` below is not
        # reached; it only runs (and must then raise) when the write
        # returns normally.
        await mc.write({"x": [b"2"], "y": [b"2"]})
        await flush_task

    # The original asserted a bare list, which only checked that the list
    # was non-empty. Keep that requirement and actually verify each shard.
    delivered = d["x"]
    assert delivered
    assert all(b"2" not in shard for shard in delivered)


def gen_bytes(percentage: float, memory_limit: int) -> bytes:
    """Return a run of ``b"0"`` bytes sized to a fraction of a limit.

    The length is ``percentage * memory_limit`` rounded down to a whole
    number of bytes.
    """
    size = math.floor(percentage * memory_limit)
    return b"0" * int(size)


@gen_test()
async def test_concurrent_puts():
    """Issue many concurrent writes and verify everything is delivered.

    Twenty write tasks each submit the same ten-shard payload. After
    gathering them and flushing, the buffer's ``shards`` and ``sizes`` must
    be empty (both inside and after the ``async with`` block), all ten
    addresses must have been seen by ``send``, and the bytes received for
    address ``0`` must total one shard length times the number of writes.
    """
    received = defaultdict(list)

    async def send(address, shards):
        # Collect every shard handed to the send callback, keyed by address.
        bucket = received[address]
        bucket.extend(shards)

    frac = 0.1
    nshards = 10
    nputs = 20
    comm_buffer = CommShardsBuffer(
        send=send,
        max_message_size=parse_bytes("2 MiB"),
        memory_limiter=ResourceLimiter(parse_bytes("100 MiB")),
        concurrency_limit=10,
    )
    payload = {}
    for key in range(nshards):
        payload[key] = gen_bytes(frac, comm_buffer.memory_limiter.limit)

    async with comm_buffer as mc:
        pending = []
        for _ in range(nputs):
            pending.append(asyncio.create_task(mc.write(payload)))

        await asyncio.gather(*pending)
        await mc.flush()

        assert not mc.shards
        assert not mc.sizes

    assert not mc.shards
    assert not mc.sizes
    assert len(received) == 10

    shard_length = len(gen_bytes(frac, comm_buffer.memory_limiter.limit))
    delivered_bytes = sum(len(shard) for shard in received[0])
    assert delivered_bytes == shard_length * nputs


@gen_test()
async def test_concurrent_puts_error():
    """Issue many concurrent writes where one ``send`` call fails.

    The fifth invocation of ``send`` raises ``OSError``. The writes and the
    flush are still awaited inside the ``async with`` block, after which
    ``raise_on_exception`` must re-raise that error. Once the block exits,
    the buffer's ``shards`` and ``sizes`` must be empty.
    """
    received = defaultdict(list)

    calls = 0

    async def send(address, shards):
        nonlocal calls
        calls += 1
        # Fail exactly once, on the fifth call.
        if calls == 5:
            raise OSError("error during send")
        bucket = received[address]
        bucket.extend(shards)
        return {"status": "OK"}

    frac = 0.1
    nshards = 10
    nputs = 20
    comm_buffer = CommShardsBuffer(
        send=send,
        max_message_size=parse_bytes("2 MiB"),
        memory_limiter=ResourceLimiter(parse_bytes("100 MiB")),
        concurrency_limit=10,
    )
    payload = {}
    for key in range(nshards):
        payload[key] = gen_bytes(frac, comm_buffer.memory_limiter.limit)

    async with comm_buffer as mc:
        pending = []
        for _ in range(nputs):
            pending.append(asyncio.create_task(mc.write(payload)))

        await asyncio.gather(*pending)
        await mc.flush()
        with pytest.raises(OSError, match="error during send"):
            mc.raise_on_exception()

    assert not mc.shards
    assert not mc.sizes