File: test_disk_buffer.py

package info (click to toggle)
dask.distributed 2024.12.1%2Bds-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 12,588 kB
  • sloc: python: 96,954; javascript: 1,549; sh: 390; makefile: 220
file content (143 lines) | stat: -rw-r--r-- 4,131 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
from __future__ import annotations

import asyncio
import os
from pathlib import Path
from typing import Any

import pytest

from distributed.shuffle._disk import DiskShardsBuffer
from distributed.shuffle._exceptions import DataUnavailable
from distributed.shuffle._limiter import ResourceLimiter
from distributed.utils_test import gen_test


def read_bytes(path: Path) -> tuple[bytes, int]:
    """Read *path* in full as binary and return ``(contents, size_in_bytes)``."""
    contents = path.read_bytes()
    # After a full read the file position equals the byte count, so the
    # size is simply the length of what we read.
    return contents, len(contents)


@gen_test()
async def test_basic(tmp_path):
    """Write two rounds of shards, flush, and read each key back concatenated."""
    async with DiskShardsBuffer(
        directory=tmp_path, read=read_bytes, memory_limiter=ResourceLimiter(None)
    ) as buf:
        # Two identical writes: reads must return the concatenation of both.
        for _ in range(2):
            await buf.write({"x": b"0" * 1000, "y": b"1" * 500})

        await buf.flush()

        x = buf.read("x")
        y = buf.read("y")

        # A key that was never written is reported as unavailable.
        with pytest.raises(DataUnavailable):
            buf.read("z")

        assert x == b"0" * 2000
        assert y == b"1" * 1000

    # Leaving the context must remove the spill directory.
    assert not os.path.exists(tmp_path)


@gen_test()
async def test_read_before_flush(tmp_path):
    """``read`` must raise RuntimeError until ``flush`` has completed."""
    shards = {"1": b"foo"}
    async with DiskShardsBuffer(
        directory=tmp_path, read=read_bytes, memory_limiter=ResourceLimiter(None)
    ) as buf:
        # Reading before anything was written is rejected outright.
        with pytest.raises(RuntimeError):
            buf.read(1)

        await buf.write(shards)

        # Still rejected: the data has been written but not flushed yet.
        with pytest.raises(RuntimeError):
            buf.read(1)

        await buf.flush()
        assert buf.read("1") == b"foo"
        # After the flush, unknown keys surface as DataUnavailable instead.
        with pytest.raises(DataUnavailable):
            buf.read(2)


@pytest.mark.parametrize("count", [2, 100, 1000])
@gen_test()
async def test_many(tmp_path, count):
    """Repeated writes of many keys concatenate correctly for every key."""
    async with DiskShardsBuffer(
        directory=tmp_path, read=read_bytes, memory_limiter=ResourceLimiter(None)
    ) as buf:
        shards = {key: str(key).encode() * 100 for key in range(count)}
        repeats = 10

        for _ in range(repeats):
            await buf.write(shards)

        await buf.flush()

        for key, payload in shards.items():
            assert buf.read(key) == payload * repeats

    # The spill directory is cleaned up when the buffer closes.
    assert not os.path.exists(tmp_path)


class BrokenDiskShardsBuffer(DiskShardsBuffer):
    """A ``DiskShardsBuffer`` whose background processing always fails.

    Used by ``test_exceptions`` to verify that an exception raised inside
    ``_process`` is stored and re-raised by subsequent ``write`` calls.
    """

    # NOTE: the original redundant ``__init__`` (which only forwarded its
    # arguments to ``super().__init__``) has been removed — Python inherits
    # the parent constructor automatically.

    async def _process(self, *args: Any, **kwargs: Any) -> None:
        # Fail unconditionally; the caller matches on "123".
        raise Exception(123)


@gen_test()
async def test_exceptions(tmp_path):
    """A failure in background processing propagates to later ``write`` calls."""
    shards = {"x": [b"0" * 1000], "y": [b"1" * 500]}
    async with BrokenDiskShardsBuffer(
        directory=tmp_path, read=read_bytes, memory_limiter=ResourceLimiter(None)
    ) as buf:
        await buf.write(shards)

        # Poll until the background task has recorded its failure.
        while not buf._exception:
            await asyncio.sleep(0.1)

        # The stored exception is re-raised on the next write attempt.
        with pytest.raises(Exception, match="123"):
            await buf.write(shards)

        await buf.flush()


class EventuallyBrokenDiskShardsBuffer(DiskShardsBuffer):
    """A ``DiskShardsBuffer`` that starts failing after an initial burst.

    The first ``concurrency_limit + 1`` invocations of ``_process`` behave
    normally; every invocation after that raises.  This simulates a failure
    that only hits work which had already been queued up behind the
    concurrency limit.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Number of _process calls that were allowed through so far.
        self.counter = 0

    async def _process(self, *args: Any, **kwargs: Any) -> None:
        # Let the initial burst succeed; fail everything queued behind it.
        if self.counter <= self.concurrency_limit:
            self.counter += 1
            return await super()._process(*args, **kwargs)
        raise Exception(123)


@gen_test()
async def test_high_pressure_flush_with_exception(tmp_path):
    """``flush`` must surface a processing error hit while writes are queued."""
    payload = {f"shard-{ix}": [f"shard-{ix}".encode() * 100] for ix in range(100)}

    async with EventuallyBrokenDiskShardsBuffer(
        directory=tmp_path, read=read_bytes, memory_limiter=ResourceLimiter(None)
    ) as buf:
        # Keep references to the tasks so they are not garbage collected.
        tasks = [asyncio.create_task(buf.write(payload)) for _ in range(10)]

        # Block until shards are genuinely queued up, i.e. the processing
        # queue is saturated while shards are still pending.
        while not buf.shards:
            # Disks are fast — yield without sleeping so the queue has no
            # chance to drain between checks.
            await asyncio.sleep(0)

        # Either flush itself raises, or raise_on_exception does afterwards.
        with pytest.raises(Exception, match="123"):
            await buf.flush()
            buf.raise_on_exception()