File: test_buffer.py

Package: dask.distributed 2022.12.1+ds.1-3 (Debian bookworm)
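
"""Tests for ``ShardsBuffer`` memory limiting in ``distributed.shuffle``."""
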
from __future__ import annotations

import asyncio
import math
from collections import defaultdict

import pytest

from dask.utils import parse_bytes

from distributed.shuffle._buffer import ShardsBuffer
from distributed.shuffle._limiter import ResourceLimiter
from distributed.utils_test import gen_test


def gen_bytes(percentage: float, limit: int) -> bytes:
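    """Return ``floor(percentage * limit)`` bytes of ``b"0"`` filler."""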
    num_bytes = int(math.floor(percentage * limit))
    return b"0" * num_bytes


class BufferTest(ShardsBuffer):
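    """:class:`ShardsBuffer` whose ``_process`` waits on ``allow_process``, so
    tests control exactly when shards are drained and memory is released.
    """
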
    def __init__(self, memory_limiter: ResourceLimiter, concurrency_limit: int) -> None:
        self.allow_process = asyncio.Event()
        self.storage: dict[str, bytes] = defaultdict(bytes)
        super().__init__(
            memory_limiter=memory_limiter, concurrency_limit=concurrency_limit
        )

    async def _process(self, id: str, shards: list[bytes]) -> None:
        await self.allow_process.wait()
        self.storage[id] += b"".join(shards)

    def read(self, id: str) -> bytes:
        return self.storage[id]


limit = parse_bytes("10.0 MiB")  # payload sizes below are fractions of this cap


@pytest.mark.parametrize(
    "big_payload",
    [
        {"big": [gen_bytes(2, limit)]},
        {"big": [gen_bytes(0.5, limit)] * 4},
        {f"big-{ix}": [gen_bytes(0.5, limit)] for ix in range(4)},
        {f"big-{ix}": [gen_bytes(0.5, limit)] * 2 for ix in range(2)},
    ],
)
@gen_test()
async def test_memory_limit(big_payload):
    small_payload = {"small": [gen_bytes(0.1, limit)]}

    limiter = ResourceLimiter(limit)

    async with BufferTest(
        memory_limiter=limiter,
        concurrency_limit=2,
    ) as buf:
        # It's OK to write nothing
        await buf.write({})

        many_small = [asyncio.create_task(buf.write(small_payload)) for _ in range(9)]
        buf.allow_process.set()
        # Puts that do not breach the limit do not block
        await asyncio.gather(*many_small)
        assert buf.memory_limiter.time_blocked_total == 0
        buf.allow_process.clear()

        # With processing paused, 11 small writes oversubscribe the limiter
        many_small = [asyncio.create_task(buf.write(small_payload)) for _ in range(11)]
        assert buf.memory_limiter
        while buf.memory_limiter.available():
            await asyncio.sleep(0.1)

        # While the limiter is saturated, further writes block
        new_put = asyncio.create_task(buf.write(small_payload))
        with pytest.raises(asyncio.TimeoutError):
            await asyncio.wait_for(asyncio.shield(new_put), 0.1)
        # Resuming processing drains the buffer and unblocks the pending writes
        buf.allow_process.set()
        await new_put
        await asyncio.gather(*many_small)

        # Wait for the limiter to be fully released before the big-payload phase
        while not buf.memory_limiter.free():
            await asyncio.sleep(0.1)
        buf.allow_process.clear()
        # A payload totaling twice the limit blocks, as does anything queued behind it
        big = asyncio.create_task(buf.write(big_payload))
        small = asyncio.create_task(buf.write(small_payload))
        with pytest.raises(asyncio.TimeoutError):
            await asyncio.wait_for(asyncio.shield(big), 0.1)
        with pytest.raises(asyncio.TimeoutError):
            await asyncio.wait_for(asyncio.shield(small), 0.1)
        # Puts only return once we're below memory limit
        buf.allow_process.set()
        await big
        await small
        # Once the big write is through, we can write without blocking again
        before = buf.memory_limiter.time_blocked_total
        await buf.write(small_payload)
        assert before == buf.memory_limiter.time_blocked_total


class BufferShardsBroken(ShardsBuffer):
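    """:class:`ShardsBuffer` whose ``_process`` raises for the shard id
    ``"error"``, used to check that the exception surfaces via
    ``raise_on_exception`` instead of being dropped.
    """
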
    def __init__(self, memory_limiter: ResourceLimiter, concurrency_limit: int) -> None:
        self.storage: dict[str, bytes] = defaultdict(bytes)
        super().__init__(
            memory_limiter=memory_limiter, concurrency_limit=concurrency_limit
        )

    async def _process(self, id: str, shards: list[bytes]) -> None:
        if id == "error":
            raise RuntimeError("Error during processing")
        self.storage[id] += b"".join(shards)

    def read(self, id: str) -> bytes:
        return self.storage[id]


@gen_test()
async def test_memory_limit_blocked_exception():
    limit = parse_bytes("10.0 MiB")

    big_payload = {
        "shard-1": [gen_bytes(2, limit)],
    }
    # _process raises for the id "error"; the payload contents never matter
    broken_payload = {
        "error": ["not-bytes"],
    }
    limiter = ResourceLimiter(limit)
    async with BufferShardsBroken(
        memory_limiter=limiter,
        concurrency_limit=2,
    ) as mf:
        big_write = asyncio.create_task(mf.write(big_payload))
        small_write = asyncio.create_task(mf.write(broken_payload))
        # The big write saturates the limiter, so the broken write blocks behind it
        await big_write
        await small_write

        await mf.flush()
        # Make sure exception is not dropped
        with pytest.raises(RuntimeError, match="Error during processing"):
            mf.raise_on_exception()