File: test_disk_buffer.py

from __future__ import annotations

import asyncio
import os

import pytest

from distributed.shuffle._disk import DiskShardsBuffer
from distributed.utils_test import gen_test


def dump(data, f):
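    # Trivial dump callback: write the raw shard bytes straight to the open file.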
    f.write(data)


def load(f):
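    # Trivial load callback: read the whole file back. An empty read means
    # end-of-file, signalled to the buffer by raising EOFError.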
    out = f.read()
    if not out:
        raise EOFError()
    return out


@gen_test()
async def test_basic(tmp_path):
    async with DiskShardsBuffer(directory=tmp_path, dump=dump, load=load) as mf:
        await mf.write({"x": [b"0" * 1000], "y": [b"1" * 500]})
        await mf.write({"x": [b"0" * 1000], "y": [b"1" * 500]})

        await mf.flush()

        x = mf.read("x")
        y = mf.read("y")

        with pytest.raises(KeyError):
            mf.read("z")

        assert x == b"0" * 2000
        assert y == b"1" * 1000

    assert not os.path.exists(tmp_path)
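

# A minimal extra check, not in the original suite: it assumes that flushing
# a buffer that never received a write is a no-op, and that reading any key
# afterwards raises KeyError, mirroring the unknown-key behaviour exercised
# in test_basic above.
@gen_test()
async def test_flush_without_writes(tmp_path):
    async with DiskShardsBuffer(directory=tmp_path, dump=dump, load=load) as mf:
        await mf.flush()

        # No key was ever written, so every read should fail cleanly
        with pytest.raises(KeyError):
            mf.read("x")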


@gen_test()
async def test_read_before_flush(tmp_path):
    payload = {"1": [b"foo"]}
    async with DiskShardsBuffer(directory=tmp_path, dump=dump, load=load) as mf:
        with pytest.raises(RuntimeError):
            mf.read("1")

        await mf.write(payload)

        with pytest.raises(RuntimeError):
            mf.read("1")

        await mf.flush()
        assert mf.read("1") == b"foo"
        with pytest.raises(KeyError):
            mf.read("2")


@pytest.mark.parametrize("count", [2, 100, 1000])
@gen_test()
async def test_many(tmp_path, count):
    async with DiskShardsBuffer(directory=tmp_path, dump=dump, load=load) as mf:
        d = {i: [str(i).encode() * 100] for i in range(count)}

        for _ in range(10):
            await mf.write(d)

        await mf.flush()

        for i in d:
            out = mf.read(i)
            assert out == str(i).encode() * 100 * 10

    assert not os.path.exists(tmp_path)
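

# A hedged sketch, not in the original suite: it assumes write() may be
# called concurrently from several tasks. The shards per key are identical
# bytes, so the assertions are insensitive to the order in which writes
# land on disk.
@gen_test()
async def test_concurrent_writes(tmp_path):
    async with DiskShardsBuffer(directory=tmp_path, dump=dump, load=load) as mf:
        payload = {"x": [b"0" * 100], "y": [b"1" * 100]}

        # Launch ten writes at once instead of awaiting them one by one
        await asyncio.gather(*(mf.write(payload) for _ in range(10)))
        await mf.flush()

        assert mf.read("x") == b"0" * 1000
        assert mf.read("y") == b"1" * 1000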


@gen_test()
async def test_exceptions(tmp_path):
    def dump_broken(data, f):
        raise Exception(123)

    async with DiskShardsBuffer(directory=tmp_path, dump=dump_broken, load=load) as mf:
        await mf.write({"x": [b"0" * 1000], "y": [b"1" * 500]})

        while not mf._exception:
            await asyncio.sleep(0.1)

        with pytest.raises(Exception, match="123"):
            await mf.write({"x": [b"0" * 1000], "y": [b"1" * 500]})

        await mf.flush()


@gen_test()
async def test_high_pressure_flush_with_exception(tmp_path):
    counter = 0
    payload = {f"shard-{ix}": [f"shard-{ix}".encode() * 100] for ix in range(100)}

    def dump_broken(data, f):
        nonlocal counter
        # Only raise for shards that were queued up behind the concurrency limit
        if counter > DiskShardsBuffer.concurrency_limit:
            raise Exception(123)
        counter += 1
        dump(data, f)

    async with DiskShardsBuffer(
        directory=tmp_path,
        dump=dump_broken,
        load=load,
    ) as mf:
        # Keep references to the tasks so they cannot be garbage collected
        tasks = [asyncio.create_task(mf.write(payload)) for _ in range(10)]

        # Wait until writes are actually queued up, i.e. there is no slot
        # left on the queue but there are still shards waiting to be written
        while not mf.shards:
            # Disks are fast; don't give the buffer time to drain the queue.
            # Only a few event-loop ticks may pass, so keep the sleep at zero
            await asyncio.sleep(0)

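        # Either flush() itself or the explicit raise_on_exception() call is
        # expected to surface the stored exception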
        with pytest.raises(Exception, match="123"):
            await mf.flush()
            mf.raise_on_exception()
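

# A hedged sketch, not in the original file: a failing load callback. The
# tests above only exercise failures in dump; this assumes that an exception
# raised by load propagates out of read() after a successful flush.
@gen_test()
async def test_load_exceptions(tmp_path):
    def load_broken(f):
        raise Exception(456)

    async with DiskShardsBuffer(directory=tmp_path, dump=dump, load=load_broken) as mf:
        await mf.write({"x": [b"0" * 1000]})
        await mf.flush()

        with pytest.raises(Exception, match="456"):
            mf.read("x")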