File: utils.py

package info (click to toggle)
dask.distributed 2022.12.1%2Bds.1-3
  • links: PTS, VCS
  • area: main
  • in suites: bookworm
  • size: 10,164 kB
  • sloc: python: 81,938; javascript: 1,549; makefile: 228; sh: 100
file content (202 lines) | stat: -rw-r--r-- 6,102 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
from __future__ import annotations

import ctypes
import struct
from collections.abc import Sequence

import dask

from distributed.utils import nbytes

# Default maximum shard size, in bytes, used by ``frame_split_size``.
# The value is read from the dask config key ``distributed.comm.shard`` and
# converted to an integer byte count with ``dask.utils.parse_bytes``.
BIG_BYTES_SHARD_SIZE = dask.utils.parse_bytes(dask.config.get("distributed.comm.shard"))


# Shared msgpack options (presumably passed as keyword arguments to the
# msgpack unpacker -- confirm at the call sites): every ``max_*_len`` limit is
# raised to 2**31 - 1, and ``strict_map_key`` / ``raw`` are switched off.
msgpack_opts = {
    f"max_{kind}_len": 2**31 - 1 for kind in ("str", "bin", "array", "map", "ext")
}
msgpack_opts.update(strict_map_key=False, raw=False)


def frame_split_size(
    frame: bytes | memoryview, n: int = BIG_BYTES_SHARD_SIZE
) -> list[memoryview]:
    """
    Split a frame into a list of frames of maximum size

    This helps us to avoid passing around very large bytestrings.

    Parameters
    ----------
    frame
        A single buffer; it is wrapped in a ``memoryview`` before slicing.
    n
        Maximum shard size in bytes. A falsy value (``0`` or ``None``) falls
        back to ``BIG_BYTES_SHARD_SIZE``.

    Returns
    -------
    list of memoryview
        ``[frame]`` (as a memoryview) when the frame already fits in ``n``
        bytes, otherwise consecutive slices of the frame. An individual item
        is never split, so when ``n`` is smaller than the frame's ``itemsize``
        each shard holds exactly one item.

    Raises
    ------
    ValueError
        If ``n`` is negative.

    Examples
    --------
    >>> [bytes(shard) for shard in frame_split_size(b'12345678', n=3)]
    [b'123', b'456', b'78']
    """
    n = n or BIG_BYTES_SHARD_SIZE
    if n < 0:
        # A negative size would turn into a negative ``range`` step below and
        # silently produce an empty list, i.e. drop the whole frame.
        raise ValueError(f"Shard size must not be negative; got {n}")
    frame = memoryview(frame)

    if frame.nbytes <= n:
        return [frame]

    # NOTE(review): slicing a memoryview acts on its first dimension only;
    # frames are presumably always 1-dimensional here -- confirm with callers.
    nitems = frame.nbytes // frame.itemsize
    # Clamp to one item per shard so that ``n < itemsize`` cannot yield a zero
    # step (``range() arg 3 must not be zero``).
    items_per_shard = max(1, n // frame.itemsize)

    return [frame[i : i + items_per_shard] for i in range(0, nitems, items_per_shard)]


def pack_frames_prelude(frames):
    """Build the length header that precedes a packed list of frames

    The header is packed with ``struct`` using the ``Q`` format code: first
    the number of frames, then one entry per frame holding the size reported
    by ``nbytes`` for that frame.
    """
    count = len(frames)
    sizes = [nbytes(frame) for frame in frames]
    return struct.pack(f"Q{count}Q", count, *sizes)


def pack_frames(frames):
    """Concatenate frames into one bytes object, preceded by a length header

    The header produced by ``pack_frames_prelude`` is placed in front of the
    frames so the individual frames can be recovered later.

    See Also
    --------
    unpack_frames
    """
    header = pack_frames_prelude(frames)
    return b"".join((header, *frames))


def unpack_frames(b):
    """Split a packed buffer back into its individual frames

    The buffer must begin with the header written by ``pack_frames``: the
    number of frames followed by the length of each frame. The returned
    frames are memoryview slices of ``b``; no data is copied.

    See Also
    --------
    pack_frames
    """
    view = memoryview(b)

    word = struct.calcsize("Q")
    (count,) = struct.unpack_from("Q", view)
    sizes = struct.unpack_from(f"{count}Q", view, word)

    # The payload starts right after the frame count and the ``count`` lengths.
    offset = word * (count + 1)
    out = []
    for size in sizes:
        out.append(view[offset : offset + size])
        offset += size

    return out


def merge_memoryviews(mvs: Sequence[memoryview]) -> memoryview:
    """
    Zero-copy "concatenate" a sequence of adjacent memoryviews.

    The result is a single memoryview over the shared underlying buffer
    (``.obj``) covering exactly the bytes that concatenating ``mvs`` would
    produce.

    An empty sequence yields an empty memoryview and a single-element sequence
    is returned as-is. Otherwise, zero-length views are ignored and every
    remaining view must:

    * Be backed by the same object as the first view
    * Be contiguous
    * Be 1-dimensional
    * Have the same format as the first view
    * Start at the address where the previous non-empty view ended

    Raises TypeError if an element is not a memoryview and ValueError if any
    of the conditions above is violated.
    """
    if not mvs:
        return memoryview(bytearray())
    if len(mvs) == 1:
        return mvs[0]

    head = mvs[0]
    if not isinstance(head, memoryview):
        raise TypeError(f"Expected memoryview; got {type(head)}")
    shared_obj, shared_format = head.obj, head.format

    # ``origin_addr`` stays 0 until the first non-empty view has been seen.
    origin_addr = 0
    total = 0
    for i, mv in enumerate(mvs):
        if not isinstance(mv, memoryview):
            raise TypeError(f"{i}: expected memoryview; got {type(mv)}")

        if not mv.nbytes:
            # Empty views contribute nothing and are not validated further.
            continue

        if mv.obj is not shared_obj:
            raise ValueError(
                f"{i}: memoryview has different buffer: {mv.obj!r} vs {shared_obj!r}"
            )
        if not mv.contiguous:
            raise ValueError(f"{i}: memoryview non-contiguous")
        if mv.ndim != 1:
            raise ValueError(f"{i}: memoryview has {mv.ndim} dimensions, not 1")
        if mv.format != shared_format:
            raise ValueError(
                f"{i}: inconsistent format: {mv.format} vs {shared_format}"
            )

        addr = address_of_memoryview(mv)
        if origin_addr == 0:
            origin_addr = addr
        elif addr != origin_addr + total:
            expected_addr = origin_addr + total
            raise ValueError(
                f"memoryview {i} does not start where the previous ends. "
                f"Expected {expected_addr:x}, starts {addr - expected_addr} byte(s) away."
            )
        total += mv.nbytes

    if total == 0:
        # Every view was empty, so the first one already is the merged result.
        assert len(head) == 0
        return head

    assert origin_addr != 0, "Underlying buffer is null pointer?!"

    # Re-slice the whole underlying buffer as raw bytes, then restore the format.
    flat = memoryview(shared_obj).cast("B")
    offset = origin_addr - address_of_memoryview(flat)

    return flat[offset : offset + total].cast(shared_format)


one_byte_carr = ctypes.c_byte * 1
# ^ a throwaway ctypes array type: its length and element type are irrelevant,
#   it only serves to obtain the address of a buffer's first byte


def address_of_memoryview(mv: memoryview) -> int:
    """
    Return the memory address of the first byte exposed by a memoryview.

    Writeable memoryviews are handled with ctypes alone; read-only ones need
    NumPy, and a ValueError is raised if it cannot be imported.
    """
    # NOTE: callers use the returned address for pointer arithmetic, to work
    # out where a memoryview starts inside its underlying buffer.
    # Python has no direct API for this, so the address is obtained through
    # ctypes and the buffer protocol:
    # https://mattgwwalker.wordpress.com/2020/10/15/address-of-a-buffer-in-python/
    try:
        return ctypes.addressof(one_byte_carr.from_buffer(mv))
    except TypeError:
        # `from_buffer` only accepts writeable buffers, so `mv` is read-only.
        # See https://bugs.python.org/issue11427 for discussion.
        # This typically comes from `deserialize_bytes`, where `mv.obj` is an
        # immutable bytestring.
        pass

    try:
        import numpy
    except ImportError:
        raise ValueError(
            f"Cannot get address of read-only memoryview {mv} since NumPy is not installed."
        )

    # NumPy accepts read-only buffers too. It could serve every case, but the
    # pure-Python route above is preferred for the common case of writeable
    # buffers (created by TCP comms, for example).
    array_interface = numpy.asarray(mv).__array_interface__
    return array_interface["data"][0]