File: test_protocol_utils.py

package info (click to toggle)
dask.distributed 2024.12.1+ds-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 12,588 kB
  • sloc: python: 96,954; javascript: 1,549; sh: 390; makefile: 220
file content (152 lines) | stat: -rw-r--r-- 4,987 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
from __future__ import annotations

import pytest

from distributed.protocol.utils import (
    merge_memoryviews,
    pack_frames,
    pack_frames_prelude,
    unpack_frames,
)


def test_pack_frames():
    """Packing two frames into bytes and unpacking them round-trips the data."""
    original = [b"123", b"asdf"]
    packed = pack_frames(original)
    assert isinstance(packed, bytes)
    assert unpack_frames(packed) == original


@pytest.mark.parametrize("extra", [b"456", b""])
def test_unpack_frames_remainder(extra):
    """Bytes trailing the packed frames are dropped by default and returned
    as one extra memoryview frame when ``remainder=True``."""
    original = [b"123", b"asdf"]
    packed = pack_frames(original)
    assert isinstance(packed, bytes)
    payload = packed + extra

    # Default mode: only the packed frames come back; ``extra`` is not included.
    assert unpack_frames(payload) == original

    # Remainder mode: whatever follows the last frame is appended as a memoryview.
    with_tail = unpack_frames(payload, remainder=True)
    assert isinstance(with_tail[-1], memoryview)
    assert with_tail == [*original, extra]


def test_unpack_frames_partial():
    """With ``partial=True`` the fully present frames are returned together
    with the lengths of the frames whose bytes are missing; without it, a
    truncated buffer raises ``AssertionError``."""
    payload = [b"123", b"asdf"]
    prelude = pack_frames_prelude(payload)

    complete = b"".join([prelude, *payload])
    truncated = b"".join([prelude, *payload[:-1]])

    # (input bytes, expected complete frames, expected missing frame lengths)
    cases = [
        (complete, payload, []),
        (truncated, payload[:-1], [4]),
        (prelude, [], [3, 4]),
    ]
    for data, expected_frames, expected_missing in cases:
        got_frames, missing_lengths = unpack_frames(data, partial=True)
        assert got_frames == expected_frames
        assert missing_lengths == expected_missing

    # Outside partial mode, the truncated buffer is rejected.
    with pytest.raises(AssertionError):
        unpack_frames(truncated)


class TestMergeMemroyviews:
    """Tests for ``merge_memoryviews``: joining adjacent slices of a single
    buffer without copying, and the errors raised for inputs it rejects."""

    def test_empty(self):
        merged = merge_memoryviews([])
        assert isinstance(merged, memoryview)
        assert len(merged) == 0

    def test_one(self):
        # A single view is handed back unchanged (same object).
        only = memoryview(bytearray(range(10)))
        assert merge_memoryviews([only]) is only

    @pytest.mark.parametrize(
        "slices",
        [
            [slice(None, 3), slice(3, None)],
            [slice(1, 3), slice(3, None)],
            [slice(1, 3), slice(3, -1)],
            [slice(0, 0), slice(None)],
            [slice(None), slice(-1, -1)],
            [slice(0, 0), slice(0, 0)],
            [slice(None, 3), slice(3, 7), slice(7, None)],
            [slice(2, 3), slice(3, 7), slice(7, 9)],
            [slice(2, 3), slice(3, 7), slice(7, 9), slice(9, 9)],
            [slice(1, 2), slice(2, 5), slice(5, 8), slice(8, None)],
        ],
    )
    def test_parts(self, slices):
        buffer = bytearray(range(10))
        view = memoryview(buffer)

        # The merged view should span from the smallest normalized start to
        # the largest normalized stop among the given slices.
        bounds = [s.indices(len(buffer))[:2] for s in slices]
        lo = min(start for start, _ in bounds)
        hi = max(stop for _, stop in bounds)
        expected = view[lo:hi]

        merged = merge_memoryviews([view[s] for s in slices])
        assert merged.obj is buffer
        assert len(merged) == len(expected)
        assert merged == expected

    def test_readonly_buffer(self):
        pytest.importorskip(
            "numpy", reason="Read-only buffer zero-copy merging requires NumPy"
        )
        data = bytes(range(10))
        view = memoryview(data)

        merged = merge_memoryviews([view[:4], view[4:]])
        assert merged.obj is data
        assert len(merged) == len(data)
        assert merged == data

    def test_catch_non_memoryview(self):
        pytest.importorskip("numpy")
        # A non-memoryview is rejected both in the first position and later on;
        # the two positions are matched with differently-capitalized patterns.
        bad_inputs = [
            ([b"1234", memoryview(b"4567")], "Expected memoryview"),
            ([memoryview(b"123"), b"1234"], "expected memoryview"),
        ]
        for parts, pattern in bad_inputs:
            with pytest.raises(TypeError, match=pattern):
                merge_memoryviews(parts)

    @pytest.mark.parametrize(
        "slices",
        [
            [slice(None, 3), slice(4, None)],
            [slice(None, 3), slice(2, None)],
            [slice(1, 3), slice(3, 6), slice(9, None)],
        ],
    )
    def test_catch_gaps(self, slices):
        # Covers both gaps and overlaps between consecutive slices.
        view = memoryview(bytearray(range(10)))
        pieces = [view[s] for s in slices]
        with pytest.raises(ValueError, match="does not start where the previous ends"):
            merge_memoryviews(pieces)

    def test_catch_different_buffer(self):
        original = bytearray(range(8))
        first = memoryview(original)
        with pytest.raises(ValueError, match="different buffer"):
            merge_memoryviews([first, memoryview(original.copy())])

    def test_catch_different_non_contiguous(self):
        reversed_view = memoryview(bytearray(range(8)))[::-1]
        with pytest.raises(ValueError, match="non-contiguous"):
            merge_memoryviews([reversed_view[:3], reversed_view[3:]])

    def test_catch_multidimensional(self):
        grid = memoryview(bytearray(range(6))).cast("B", [3, 2])
        with pytest.raises(ValueError, match="has 2 dimensions, not 1"):
            merge_memoryviews([grid[:1], grid[1:]])

    def test_catch_different_formats(self):
        view = memoryview(bytearray(range(8)))
        # Second half is reinterpreted as unsigned ints, so formats disagree.
        with pytest.raises(ValueError, match="inconsistent format: I vs B"):
            merge_memoryviews([view[:4], view[4:].cast("I")])