from __future__ import annotations
import asyncio
import os
from pathlib import Path
from typing import Any
import pytest
from distributed.shuffle._disk import DiskShardsBuffer
from distributed.shuffle._exceptions import DataUnavailable
from distributed.shuffle._limiter import ResourceLimiter
from distributed.utils_test import gen_test
def read_bytes(path: Path) -> tuple[bytes, int]:
    """Read the entire file at *path* and return ``(contents, size_in_bytes)``.

    Serves as the ``read`` callback for :class:`DiskShardsBuffer` in these
    tests; the reported size is simply the number of bytes read.
    """
    data = path.read_bytes()
    return data, len(data)
@gen_test()
async def test_basic(tmp_path):
    """Shards written repeatedly under a key concatenate on read.

    Unknown keys raise ``DataUnavailable`` and the on-disk directory is
    removed once the buffer is closed.
    """
    async with DiskShardsBuffer(
        directory=tmp_path, read=read_bytes, memory_limiter=ResourceLimiter(None)
    ) as buf:
        for _ in range(2):
            await buf.write({"x": b"0" * 1000, "y": b"1" * 500})
        await buf.flush()

        x = buf.read("x")
        y = buf.read("y")
        with pytest.raises(DataUnavailable):
            buf.read("z")

        assert x == b"0" * 2000
        assert y == b"1" * 1000
    # Closing the buffer cleans up its spill directory.
    assert not os.path.exists(tmp_path)
@gen_test()
async def test_read_before_flush(tmp_path):
    """``read`` is only legal after ``flush``; missing keys then raise."""
    payload = {"1": b"foo"}
    async with DiskShardsBuffer(
        directory=tmp_path, read=read_bytes, memory_limiter=ResourceLimiter(None)
    ) as buf:
        # Before any flush, reading is a programming error regardless of
        # whether data was written (the key is irrelevant here).
        with pytest.raises(RuntimeError):
            buf.read(1)

        await buf.write(payload)
        with pytest.raises(RuntimeError):
            buf.read(1)

        await buf.flush()
        assert buf.read("1") == b"foo"
        # After a flush, an unknown key raises DataUnavailable instead.
        with pytest.raises(DataUnavailable):
            buf.read(2)
@pytest.mark.parametrize("count", [2, 100, 1000])
@gen_test()
async def test_many(tmp_path, count):
    """Many keys written many times each read back as the concatenation."""
    async with DiskShardsBuffer(
        directory=tmp_path, read=read_bytes, memory_limiter=ResourceLimiter(None)
    ) as buf:
        shards = {i: str(i).encode() * 100 for i in range(count)}
        for _ in range(10):
            await buf.write(shards)
        await buf.flush()

        for key, value in shards.items():
            assert buf.read(key) == value * 10
    # The spill directory is removed on close.
    assert not os.path.exists(tmp_path)
class BrokenDiskShardsBuffer(DiskShardsBuffer):
    """A ``DiskShardsBuffer`` whose background processing always fails.

    Used to verify that an exception raised while shards are spilled to disk
    surfaces to callers of ``write``/``flush``.

    The redundant ``__init__`` override (which only forwarded to
    ``super().__init__``) has been removed; the inherited constructor is used.
    """

    async def _process(self, *args: Any, **kwargs: Any) -> None:
        # Fail unconditionally; the numeric payload lets tests match "123".
        raise Exception(123)
@gen_test()
async def test_exceptions(tmp_path):
    """A failure in the background disk writer surfaces on later writes."""
    async with BrokenDiskShardsBuffer(
        directory=tmp_path, read=read_bytes, memory_limiter=ResourceLimiter(None)
    ) as buf:
        await buf.write({"x": [b"0" * 1000], "y": [b"1" * 500]})

        # Poll until the always-failing _process has recorded its exception.
        while not buf._exception:
            await asyncio.sleep(0.1)

        with pytest.raises(Exception, match="123"):
            await buf.write({"x": [b"0" * 1000], "y": [b"1" * 500]})

        # Flushing after the failure must still complete cleanly.
        await buf.flush()
class EventuallyBrokenDiskShardsBuffer(DiskShardsBuffer):
    """A ``DiskShardsBuffer`` that starts failing once work piles up.

    The first ``concurrency_limit + 1`` calls to ``_process`` succeed; any
    call beyond that raises, simulating a failure that hits only work which
    was queued up behind the concurrency limit.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Number of _process calls allowed through so far.
        self.counter = 0

    async def _process(self, *args: Any, **kwargs: Any) -> None:
        # We only want to raise for work that was queued up before.
        seen = self.counter
        if seen > self.concurrency_limit:
            raise Exception(123)
        self.counter = seen + 1
        return await super()._process(*args, **kwargs)
@gen_test()
async def test_high_pressure_flush_with_exception(tmp_path):
    """Flushing under heavy concurrent writes still propagates a late failure."""
    payload = {f"shard-{ix}": [f"shard-{ix}".encode() * 100] for ix in range(100)}
    async with EventuallyBrokenDiskShardsBuffer(
        directory=tmp_path, read=read_bytes, memory_limiter=ResourceLimiter(None)
    ) as buf:
        tasks = [asyncio.create_task(buf.write(payload)) for _ in range(10)]

        # Wait until writes are actually queued up: every queue slot is
        # taken while shards are still pending in memory.
        while not buf.shards:
            # Disks are fast — yield without sleeping so the queue is not
            # drained before we observe the backlog.
            await asyncio.sleep(0)

        with pytest.raises(Exception, match="123"):
            await buf.flush()
            buf.raise_on_exception()