from __future__ import annotations

import asyncio
import math
from collections import defaultdict

import pytest

from dask.utils import parse_bytes

from distributed.shuffle._buffer import ShardsBuffer
from distributed.shuffle._limiter import ResourceLimiter
from distributed.utils_test import gen_test


def gen_bytes(percentage: float, limit: int) -> bytes:
    num_bytes = int(math.floor(percentage * limit))
    return b"0" * num_bytes
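

# BufferTest gates _process on an asyncio.Event so the test controls exactly
# when buffered shards are drained and their memory is released to the limiter.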
class BufferTest(ShardsBuffer):
    def __init__(
        self, memory_limiter: ResourceLimiter, concurrency_limit: int
    ) -> None:
        self.allow_process = asyncio.Event()
        self.storage: dict[str, bytes] = defaultdict(bytes)
        super().__init__(
            memory_limiter=memory_limiter, concurrency_limit=concurrency_limit
        )

    async def _process(self, id: str, shards: list[bytes]) -> None:
        await self.allow_process.wait()
        self.storage[id] += b"".join(shards)

    def read(self, id: str) -> bytes:
        return self.storage[id]
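

# Payloads below are sized as fractions or multiples of this 10 MiB limit.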
limit = parse_bytes("10.0 MiB")


@pytest.mark.parametrize(
    "big_payload",
    [
        {"big": [gen_bytes(2, limit)]},
        {"big": [gen_bytes(0.5, limit)] * 4},
        {f"big-{ix}": [gen_bytes(0.5, limit)] for ix in range(4)},
        {f"big-{ix}": [gen_bytes(0.5, limit)] * 2 for ix in range(2)},
    ],
)
@gen_test()
async def test_memory_limit(big_payload):
    small_payload = {"small": [gen_bytes(0.1, limit)]}
    limiter = ResourceLimiter(limit)

    async with BufferTest(
        memory_limiter=limiter,
        concurrency_limit=2,
    ) as buf:
        # It's OK to write nothing
        await buf.write({})
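
        # Nine writes of 0.1 * limit each stay below the limit in total, so
        # the limiter never saturates.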
        many_small = [asyncio.create_task(buf.write(small_payload)) for _ in range(9)]
        buf.allow_process.set()
        many_small = asyncio.gather(*many_small)
        # Puts that do not breach the limit do not block
        await many_small
        assert buf.memory_limiter.time_blocked_total == 0

        buf.allow_process.clear()
        many_small = [asyncio.create_task(buf.write(small_payload)) for _ in range(11)]
        assert buf.memory_limiter
        while buf.memory_limiter.available():
            await asyncio.sleep(0.1)
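        # One more write on a saturated limiter must block; shield() keeps the
        # wait_for timeout from cancelling the still-pending write task.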
        new_put = asyncio.create_task(buf.write(small_payload))
        with pytest.raises(asyncio.TimeoutError):
            await asyncio.wait_for(asyncio.shield(new_put), 0.1)

        buf.allow_process.set()
        many_small = asyncio.gather(*many_small)
        await new_put
        while not buf.memory_limiter.free():
            await asyncio.sleep(0.1)
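
        # The buffer is fully drained; block processing again so the big write
        # (at least 2x the limit in every parametrization) saturates the
        # limiter and the small write blocks behind it.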
        buf.allow_process.clear()

        big = asyncio.create_task(buf.write(big_payload))
        small = asyncio.create_task(buf.write(small_payload))
        with pytest.raises(asyncio.TimeoutError):
            await asyncio.wait_for(asyncio.shield(big), 0.1)
        with pytest.raises(asyncio.TimeoutError):
            await asyncio.wait_for(asyncio.shield(small), 0.1)
        # Puts only return once we're below memory limit
        buf.allow_process.set()
        await big
        await small

        # Once the big write is through, we can write without blocking again
        before = buf.memory_limiter.time_blocked_total
        await buf.write(small_payload)
        assert before == buf.memory_limiter.time_blocked_total
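

# A buffer whose _process raises for the shard id "error", used to check that
# exceptions raised during processing are surfaced rather than dropped.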
class BufferShardsBroken(ShardsBuffer):
    def __init__(
        self, memory_limiter: ResourceLimiter, concurrency_limit: int
    ) -> None:
        self.storage: dict[str, bytes] = defaultdict(bytes)
        super().__init__(
            memory_limiter=memory_limiter, concurrency_limit=concurrency_limit
        )

    async def _process(self, id: str, shards: list[bytes]) -> None:
        if id == "error":
            raise RuntimeError("Error during processing")
        self.storage[id] += b"".join(shards)

    def read(self, id: str) -> bytes:
        return self.storage[id]


@gen_test()
async def test_memory_limit_blocked_exception():
    limit = parse_bytes("10.0 MiB")
    big_payload = {
        "shard-1": [gen_bytes(2, limit)],
    }
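    # _process raises for the id "error" before joining shards, so the
    # non-bytes value below is never concatenated.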
    broken_payload = {
        "error": ["not-bytes"],
    }
    limiter = ResourceLimiter(limit)
    async with BufferShardsBroken(
        memory_limiter=limiter,
        concurrency_limit=2,
    ) as mf:
        big_write = asyncio.create_task(mf.write(big_payload))
        small_write = asyncio.create_task(mf.write(broken_payload))
        # The broken write hits the memory limit and blocks behind the big write
        await big_write
        await small_write
        await mf.flush()
        # Make sure the exception is not dropped
        with pytest.raises(RuntimeError, match="Error during processing"):
            mf.raise_on_exception()