# Tests for distributed.shuffle._memory.MemoryShardsBuffer.
from __future__ import annotations
import pytest
from distributed.shuffle._exceptions import DataUnavailable
from distributed.shuffle._memory import MemoryShardsBuffer
from distributed.utils_test import gen_test
def deserialize_bytes(buffer: bytes) -> bytes:
    """Return *buffer* unchanged.

    Passed as the ``deserialize`` callable to ``MemoryShardsBuffer`` in the
    tests of this module, so the payloads read back can be compared directly
    against the bytes that were written.
    """
    return buffer
@gen_test()
async def test_basic():
    """Write the same two-key payload twice, flush, and read it back.

    Asserts that each key returns one entry per write, in a list, and that
    reading a key that was never written raises ``DataUnavailable``.
    """
    payload = {"x": b"0" * 1000, "y": b"1" * 500}
    async with MemoryShardsBuffer(deserialize=deserialize_bytes) as mf:
        # Two identical writes: every key should end up with two shards.
        await mf.write(payload)
        await mf.write(payload)

        await mf.flush()

        x = mf.read("x")
        y = mf.read("y")

        # "z" was never written.
        with pytest.raises(DataUnavailable):
            mf.read("z")

        assert x == [b"0" * 1000] * 2
        assert y == [b"1" * 500] * 2
@gen_test()
async def test_read_before_flush():
    """Check that ``read`` is rejected until the buffer has been flushed.

    Asserts that ``read`` raises ``RuntimeError`` both before any write and
    after a write that has not been flushed yet. After ``flush`` the written
    key is readable, and an unknown key raises ``DataUnavailable``.
    """
    payload = {"1": b"foo"}
    async with MemoryShardsBuffer(deserialize=deserialize_bytes) as mf:
        # Nothing written, nothing flushed.
        with pytest.raises(RuntimeError):
            mf.read("1")

        await mf.write(payload)

        # Written but still not flushed.
        with pytest.raises(RuntimeError):
            mf.read("1")

        await mf.flush()
        assert mf.read("1") == [b"foo"]

        # "2" was never written.
        with pytest.raises(DataUnavailable):
            mf.read("2")
@pytest.mark.parametrize("count", [2, 100, 1000])
@gen_test()
async def test_many(count):
    """Write ``count`` distinct keys several times and read every key back.

    Each key ``str(i)`` maps to ``str(i).encode() * 100``. The same mapping is
    written ``repeats`` times before a single flush; afterwards every key is
    asserted to return that payload ``repeats`` times.
    """
    repeats = 10
    async with MemoryShardsBuffer(deserialize=deserialize_bytes) as mf:
        d = {str(i): str(i).encode() * 100 for i in range(count)}

        for _ in range(repeats):
            await mf.write(d)

        await mf.flush()

        # Keys are already strings, so no str() conversion is needed; the
        # expected payload is rebuilt from the key rather than taken from ``d``.
        for key in d:
            assert mf.read(key) == [key.encode() * 100] * repeats
|