from __future__ import annotations

import logging

import msgpack

import dask.config

from distributed.protocol import pickle
from distributed.protocol.compression import decompress, maybe_compress
from distributed.protocol.serialize import (
    Pickled,
    Serialize,
    Serialized,
    ToPickle,
    merge_and_deserialize,
    msgpack_decode_default,
    msgpack_encode_default,
    serialize_and_split,
)
from distributed.protocol.utils import msgpack_opts
from distributed.utils import ensure_memoryview

logger = logging.getLogger(__name__)
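
# The wire format, summarized from the functions below:
#
#   frames[0]     msgpack-encoded `msg`, in which Serialize/Serialized objects are
#                 replaced by {"__Serialized__": i} and ToPickle/Pickled objects by
#                 {"__Pickled__": i}
#   frames[i]     msgpack-encoded sub-header for that object
#   frames[i+1:]  the object's (possibly compressed) data frames, whose count is
#                 given by the sub-header's "num-sub-frames" entry
#
# A rough round trip looks like ``loads(dumps({"data": Serialize(obj)}))``, which
# returns the original message with ``obj`` deserialized in place of the wrapper.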


def dumps(  # type: ignore[no-untyped-def]
    msg, serializers=None, on_error="message", context=None, frame_split_size=None
) -> list:
"""Transform Python message to bytestream suitable for communication
Developer Notes
---------------
The approach here is to use `msgpack.dumps()` to serialize `msg` and
write the result to the first output frame. If `msgpack.dumps()`
encounters an object it cannot serialize like a NumPy array, it is handled
out-of-band by `_encode_default()` and appended to the output frame list.
"""
    try:
        if context and "compression" in context:
            compress_opts = {"compression": context["compression"]}
        else:
            compress_opts = {}
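
        # Compress each frame at most once, respecting any compression already
        # recorded in the header (e.g. for frames that arrived pre-serialized).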
        def _inplace_compress_frames(header, frames):
            compression = list(header.get("compression", [None] * len(frames)))

            for i in range(len(frames)):
                if compression[i] is None:
                    compression[i], frames[i] = maybe_compress(
                        frames[i], **compress_opts
                    )

            header["compression"] = tuple(compression)
        def create_serialized_sub_frames(obj: Serialized | Serialize) -> list:
            if isinstance(obj, Serialized):
                sub_header, sub_frames = obj.header, obj.frames
            else:
                sub_header, sub_frames = serialize_and_split(
                    obj,
                    serializers=serializers,
                    on_error=on_error,
                    context=context,
                    size=frame_split_size,
                )
                _inplace_compress_frames(sub_header, sub_frames)
            sub_header["num-sub-frames"] = len(sub_frames)
            sub_header = msgpack.dumps(
                sub_header, default=msgpack_encode_default, use_bin_type=True
            )
            return [sub_header] + sub_frames
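
        # ToPickle/Pickled objects use pickle protocol 5: objects that expose
        # out-of-band buffers are collected into `sub_frames` instead of being
        # copied into the pickle byte stream itself.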
        def create_pickled_sub_frames(obj: Pickled | ToPickle) -> list:
            if isinstance(obj, Pickled):
                sub_header, sub_frames = obj.header, obj.frames
            else:
                sub_frames = []
                sub_header = {
                    "pickled-obj": pickle.dumps(
                        obj.data,
                        # In order to support len() and slicing, we convert
                        # `PickleBuffer` objects to memoryviews of bytes.
                        buffer_callback=lambda x: sub_frames.append(
                            ensure_memoryview(x)
                        ),
                    )
                }
                _inplace_compress_frames(sub_header, sub_frames)

            sub_header["num-sub-frames"] = len(sub_frames)
            sub_header = msgpack.dumps(sub_header)
            return [sub_header] + sub_frames
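
        # Slot 0 is reserved for the top-level msgpack payload, so the offsets
        # recorded by `_encode_default()` below are always greater than zero.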
        frames = [None]

        def _encode_default(obj):
            if isinstance(obj, (Serialize, Serialized)):
                offset = len(frames)
                frames.extend(create_serialized_sub_frames(obj))
                return {"__Serialized__": offset}
            elif isinstance(obj, (ToPickle, Pickled)):
                offset = len(frames)
                frames.extend(create_pickled_sub_frames(obj))
                return {"__Pickled__": offset}
            else:
                return msgpack_encode_default(obj)
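
        # Fill in the reserved first frame; the placeholder dicts returned by
        # `_encode_default()` point at the sub-frames appended above.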
        frames[0] = msgpack.dumps(msg, default=_encode_default, use_bin_type=True)
        return frames

    except Exception:
        logger.critical("Failed to Serialize", exc_info=True)
        raise


def loads(frames, deserialize=True, deserializers=None):
    """Transform bytestream back into Python value"""
    allow_pickle = dask.config.get("distributed.scheduler.pickle")
    try:
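
        # Inverse of `_encode_default()` in `dumps()`: the placeholder dicts hold
        # the index of the msgpack-encoded sub-header; the data frames follow it.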
        def _decode_default(obj):
            offset = obj.get("__Serialized__", 0)
            if offset > 0:
                sub_header = msgpack.loads(
                    frames[offset],
                    object_hook=msgpack_decode_default,
                    use_list=False,
                    **msgpack_opts,
                )
                offset += 1
                sub_frames = frames[offset : offset + sub_header["num-sub-frames"]]
                if deserialize:
                    if "compression" in sub_header:
                        sub_frames = decompress(sub_header, sub_frames)
                    return merge_and_deserialize(
                        sub_header, sub_frames, deserializers=deserializers
                    )
                else:
                    return Serialized(sub_header, sub_frames)
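
            # Pickled payloads are only unpickled when the configuration allows it;
            # unpickling untrusted bytes can execute arbitrary code.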
offset = obj.get("__Pickled__", 0)
if offset > 0:
sub_header = msgpack.loads(frames[offset])
offset += 1
sub_frames = frames[offset : offset + sub_header["num-sub-frames"]]
if allow_pickle:
return pickle.loads(sub_header["pickled-obj"], buffers=sub_frames)
else:
raise ValueError(
"Unpickle on the Scheduler isn't allowed, set `distributed.scheduler.pickle=true`"
)
return msgpack_decode_default(obj)
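
        # Frame 0 always holds the top-level msgpack message written by `dumps()`.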
        return msgpack.loads(
            frames[0], object_hook=_decode_default, use_list=False, **msgpack_opts
        )

    except Exception:
        logger.critical("Failed to deserialize", exc_info=True)
        raise