1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152
|
from __future__ import annotations
import pytest
from distributed.protocol.utils import (
merge_memoryviews,
pack_frames,
pack_frames_prelude,
unpack_frames,
)
def test_pack_frames():
    """Frames packed into a single ``bytes`` blob unpack to the same frames."""
    original = [b"123", b"asdf"]

    packed = pack_frames(original)
    assert isinstance(packed, bytes)

    # Round trip: unpacking the blob must give back what went in.
    assert unpack_frames(packed) == original
@pytest.mark.parametrize("extra", [b"456", b""])
def test_unpack_frames_remainder(extra):
    """Check how ``unpack_frames`` treats bytes trailing the packed frames.

    By default the trailing bytes are not part of the result. With
    ``remainder=True`` they are returned as one additional, final frame that
    is a ``memoryview`` -- also when ``extra`` is empty.
    """
    original = [b"123", b"asdf"]
    packed = pack_frames(original)
    assert isinstance(packed, bytes)

    padded = packed + extra

    # Default mode: only the packed frames come back.
    assert unpack_frames(padded) == original

    # remainder=True: the leftover bytes are appended as the last frame.
    with_tail = unpack_frames(padded, remainder=True)
    assert isinstance(with_tail[-1], memoryview)
    assert with_tail == original + [extra]
def test_unpack_frames_partial():
    """Check ``unpack_frames(..., partial=True)`` on complete and cut-off input.

    In partial mode the call returns a pair: the frames that could be
    unpacked and the lengths of the frames that are still missing. Without
    ``partial=True`` a cut-off buffer is expected to raise ``AssertionError``.
    """
    payload = [b"123", b"asdf"]
    prelude = pack_frames_prelude(payload)
    frames = [prelude, *payload]
    truncated = b"".join(frames[:-1])  # prelude + first frame only

    # Complete input: everything is unpacked, nothing is missing.
    unpacked, missing_lengths = unpack_frames(b"".join(frames), partial=True)
    assert unpacked == frames[1:]
    assert missing_lengths == []

    # Last frame cut off: its length (4) is reported as missing.
    unpacked, missing_lengths = unpack_frames(truncated, partial=True)
    assert unpacked == frames[1:-1]
    assert missing_lengths == [4]

    # Prelude only: no frames yet, both lengths are reported as missing.
    unpacked, missing_lengths = unpack_frames(prelude, partial=True)
    assert unpacked == []
    assert missing_lengths == [3, 4]

    # The same cut-off buffer is rejected when partial mode is not requested.
    with pytest.raises(AssertionError):
        unpack_frames(truncated)
class TestMergeMemroyviews:
    """Tests for ``merge_memoryviews``: successful merges and rejected inputs.

    The spelling of the class name is kept as-is so that existing test IDs
    stay stable.
    """

    def test_empty(self):
        """Merging an empty list yields a zero-length ``memoryview``."""
        merged = merge_memoryviews([])
        assert isinstance(merged, memoryview)
        assert len(merged) == 0

    def test_one(self):
        """A single view is handed back as the very same object."""
        buffer = bytearray(range(10))
        view = memoryview(buffer)
        assert merge_memoryviews([view]) is view

    @pytest.mark.parametrize(
        "slices",
        [
            [slice(None, 3), slice(3, None)],
            [slice(1, 3), slice(3, None)],
            [slice(1, 3), slice(3, -1)],
            [slice(0, 0), slice(None)],
            [slice(None), slice(-1, -1)],
            [slice(0, 0), slice(0, 0)],
            [slice(None, 3), slice(3, 7), slice(7, None)],
            [slice(2, 3), slice(3, 7), slice(7, 9)],
            [slice(2, 3), slice(3, 7), slice(7, 9), slice(9, 9)],
            [slice(1, 2), slice(2, 5), slice(5, 8), slice(8, None)],
        ],
    )
    def test_parts(self, slices):
        """Slices of one buffer (some of them empty) merge into a single view.

        The expected view spans from the lowest resolved start to the highest
        resolved stop among ``slices`` and must refer to the original buffer.
        """
        buffer = bytearray(range(10))
        whole = memoryview(buffer)

        # Resolve every slice against the buffer length to get concrete bounds.
        bounds = [s.indices(10) for s in slices]
        lowest = min(start for start, _stop, _step in bounds)
        highest = max(stop for _start, stop, _step in bounds)
        expected = whole[lowest:highest]

        pieces = [whole[s] for s in slices]
        merged = merge_memoryviews(pieces)

        assert merged.obj is buffer
        assert len(merged) == len(expected)
        assert merged == expected

    def test_readonly_buffer(self):
        """Views over a read-only ``bytes`` object are merged as well."""
        pytest.importorskip(
            "numpy", reason="Read-only buffer zero-copy merging requires NumPy"
        )
        buffer = bytes(range(10))
        whole = memoryview(buffer)

        merged = merge_memoryviews([whole[:4], whole[4:]])

        assert merged.obj is buffer
        assert len(merged) == len(buffer)
        assert merged == buffer

    def test_catch_non_memoryview(self):
        """A non-``memoryview`` element raises ``TypeError``.

        The two expected messages differ in capitalisation depending on
        whether the offending element comes first or later in the list.
        """
        pytest.importorskip("numpy")

        # Offending element in first position.
        with pytest.raises(TypeError, match="Expected memoryview"):
            merge_memoryviews([b"1234", memoryview(b"4567")])

        # Offending element in a later position.
        with pytest.raises(TypeError, match="expected memoryview"):
            merge_memoryviews([memoryview(b"123"), b"1234"])

    @pytest.mark.parametrize(
        "slices",
        [
            [slice(None, 3), slice(4, None)],
            [slice(None, 3), slice(2, None)],
            [slice(1, 3), slice(3, 6), slice(9, None)],
        ],
    )
    def test_catch_gaps(self, slices):
        """Slices that leave a gap or overlap are rejected with ``ValueError``."""
        buffer = bytearray(range(10))
        whole = memoryview(buffer)
        pieces = [whole[s] for s in slices]

        with pytest.raises(ValueError, match="does not start where the previous ends"):
            merge_memoryviews(pieces)

    def test_catch_different_buffer(self):
        """Views backed by two distinct buffers are rejected."""
        buffer = bytearray(range(8))
        view = memoryview(buffer)

        with pytest.raises(ValueError, match="different buffer"):
            merge_memoryviews([view, memoryview(buffer.copy())])

    def test_catch_different_non_contiguous(self):
        """Slices of a reversed (``[::-1]``) view are rejected as non-contiguous."""
        buffer = bytearray(range(8))
        reversed_view = memoryview(buffer)[::-1]

        with pytest.raises(ValueError, match="non-contiguous"):
            merge_memoryviews([reversed_view[:3], reversed_view[3:]])

    def test_catch_multidimensional(self):
        """Views cast to a 3x2 shape are rejected for not being 1-dimensional."""
        buffer = bytearray(range(6))
        grid = memoryview(buffer).cast("B", [3, 2])

        with pytest.raises(ValueError, match="has 2 dimensions, not 1"):
            merge_memoryviews([grid[:1], grid[1:]])

    def test_catch_different_formats(self):
        """Views whose formats differ (``B`` vs ``I``) are rejected."""
        buffer = bytearray(range(8))
        view = memoryview(buffer)

        with pytest.raises(ValueError, match="inconsistent format: I vs B"):
            merge_memoryviews([view[:4], view[4:].cast("I")])
|