1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 
528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649 650 651 652 653 654 655 656 657 658 659 660 661 662 663 664 665 666 667 668 669 670 671 672 673 674 675 676 677 678 679 680 681 682 683 684 685 686 687 688 689 690 691 692 693 694 695 696 697 698 699 700 701 702 703 704 705 706 707 708 709 710 711 712 713 714 715 716 717 718 719 720 721 722 723 724 725 726 727 728 729 730 731 732 733 734 735 736 737 738 739 740 741 742 743 744 745 746 747 748 749 750 751 752 753 754 755 756 757 758 759 760 761 762 763 764 765 766 767 768 769 770 771 772 773 774 775 776 777 778 779 780 781 782 783 784 785 786 787 788 789 790 791 792 793 794 795 796 797 798 799 800 801 802 803 804 805 806 807 808 809 810 811 812 813 814 815 816 817 818 819 820 821 822 823 824 825 826 827 828 829 830 831 832 833 834 835 836 837 838 839 840 841 842 843 844 845 846 847 848 849 850 851 852 853 854 855 856 857 858 859 860 861 862 863 864 865 866 867 868 869 870 871 872 873 874 875 876 877 878 879 880 881 882 883 884 885 886 887 888 889 890 891 892 893 894 895 896 897 898 899 900 901 902 903 904 905 906 907 908 909 910 911 912 913 914 915 916 917 918 919 920 921 922 923 924 925 926 927 928 929 930
|
from __future__ import annotations
import codecs
import importlib
import traceback
from array import array
from enum import Enum
from functools import partial
from types import ModuleType
from typing import Any, Literal
import msgpack
import dask
from dask.base import normalize_token
from dask.utils import typename
from distributed.protocol import pickle
from distributed.protocol.compression import decompress, maybe_compress
from distributed.protocol.utils import (
frame_split_size,
merge_memoryviews,
msgpack_opts,
pack_frames_prelude,
unpack_frames,
)
from distributed.utils import ensure_memoryview, has_keyword
# Type-dispatched registries backing the "dask" serialization family:
# ``dask_dumps`` looks up a handler by ``type(x)`` in ``dask_serialize`` and
# ``dask_loads`` looks up the inverse in ``dask_deserialize``.
dask_serialize = dask.utils.Dispatch(name="dask_serialize")
dask_deserialize = dask.utils.Dispatch(name="dask_deserialize")

# Modules already imported through ``import_allowed_module``, keyed by name.
_cached_allowed_modules: dict[str, ModuleType] = dict()
def dask_dumps(x, context=None):
    """Serialize ``x`` with the handler registered for its type in ``dask_serialize``.

    The handler is called with ``context=`` only if it accepts that keyword.
    The returned header wraps the handler's own header under ``"sub-header"``
    and records the pickled type so ``dask_loads`` can find the inverse.

    Raises
    ------
    NotImplementedError
        If no handler is registered for ``type(x)``; the message is the type
        name. ``serialize`` treats this as "try the next family".
    """
    cls = type(x)
    type_name = typename(cls)
    try:
        handler = dask_serialize.dispatch(cls)
    except TypeError:
        raise NotImplementedError(type_name)

    extra = {"context": context} if has_keyword(handler, "context") else {}
    sub_header, frames = handler(x, **extra)
    return {
        "sub-header": sub_header,
        "type": type_name,
        "type-serialized": pickle.dumps(cls),
        "serializer": "dask",
    }, frames
def dask_loads(header, frames):
    """Inverse of ``dask_dumps``: unpickle the recorded type and dispatch on it.

    NOTE: this unpickles ``header["type-serialized"]``; only use it on data
    from trusted peers.
    """
    loader = dask_deserialize.dispatch(pickle.loads(header["type-serialized"]))
    return loader(header["sub-header"], frames)
def pickle_dumps(x, context=None):
    """Pickle ``x`` and collect its out-of-band buffers as extra frames.

    Frame 0 is the pickle stream. Each buffer passed to ``buffer_callback`` is
    appended as a ``memoryview``. The header's ``"writeable"`` tuple records,
    per buffer, whether it was writeable, so ``pickle_loads`` can restore
    that. The pickle protocol comes from ``context["pickle-protocol"]`` when a
    non-empty context is given, otherwise ``None`` is passed.
    """
    buffers = []
    writeable_flags = []

    def collect(buf):
        view = memoryview(buf)
        buffers.append(view)
        writeable_flags.append(not view.readonly)

    protocol = None
    if context:
        protocol = context.get("pickle-protocol", None)
    stream = pickle.dumps(x, buffer_callback=collect, protocol=protocol)

    header = {"serializer": "pickle", "writeable": tuple(writeable_flags)}
    return header, [stream, *buffers]
def pickle_loads(header, frames):
    """Unpickle ``frames[0]`` using ``frames[1:]`` as out-of-band buffers.

    If the header carries ``"writeable"`` flags, any buffer whose read-only
    state disagrees with its flag is copied: into a ``bytearray`` when it
    should be writeable, into ``bytes`` when it should be read-only. Without
    flags, buffers are used as-is.
    """
    stream = frames[0]
    raw_buffers = frames[1:]
    flags = header.get("writeable") or (None,) * len(raw_buffers)

    buffers = []
    for flag, view in zip(flags, map(ensure_memoryview, raw_buffers)):
        # flag == readonly means the recorded mutability is not what we hold.
        if flag == view.readonly:
            view = memoryview(bytearray(view) if flag else bytes(view))
        buffers.append(view)
    return pickle.loads(stream, buffers=buffers)
def import_allowed_module(name):
    """Import module ``name`` if its root package is allow-listed.

    Successful imports are cached in ``_cached_allowed_modules``. The allow-list
    is the ``distributed.scheduler.allowed-imports`` config value; only the
    part of ``name`` before the first ``.`` is compared against it.

    Raises
    ------
    UnicodeEncodeError
        If ``name`` contains non-ASCII characters.
    RuntimeError
        If the root is empty or not allow-listed.
    """
    if name in _cached_allowed_modules:
        return _cached_allowed_modules[name]

    # Rejects non-ASCII names by raising UnicodeEncodeError.
    name = name.encode("ascii").decode()
    root = name.partition(".")[0]

    # An empty root is refused explicitly, even if "" sneaks into the config;
    # the config is only consulted for a non-empty root.
    if not root or root not in dask.config.get("distributed.scheduler.allowed-imports"):
        raise RuntimeError(
            f"Importing {repr(name)} is not allowed, please add it to the list of "
            "allowed modules the scheduler can import via the "
            "distributed.scheduler.allowed-imports configuration setting."
        )

    module = importlib.import_module(name)
    _cached_allowed_modules[name] = module
    return module
def msgpack_decode_default(obj):
    """
    Decoding hook for msgpack: turn the marker dicts written by
    ``msgpack_encode_default`` back into objects.

    - ``"__Enum__"``: looks the enum class up via ``import_allowed_module``
      and returns the named member.
    - ``"__Set__"``: rebuilds a ``set`` from ``"as-list"``.
    - ``"__Serialized__"``: wraps ``"data"`` in ``Serialized`` without
      deserializing it.

    Any other mapping is returned unchanged.
    """
    if "__Enum__" in obj:
        module = import_allowed_module(obj["__module__"])
        enum_cls = getattr(module, obj["__name__"])
        return getattr(enum_cls, obj["name"])

    if "__Set__" in obj:
        return set(obj["as-list"])

    if "__Serialized__" in obj:
        # Kept as Serialized on purpose: deserializing would need pickle, which
        # the scheduler must not run for security reasons. The payload is
        # forwarded untouched to the workers, which deserialize it.
        return Serialized(*obj["data"])

    return obj
def msgpack_encode_default(obj):
    """
    Encoding hook for msgpack: replace types msgpack cannot represent with
    marker dicts understood by ``msgpack_decode_default``.

    - ``Serialize`` -> ``{"__Serialized__": True, "data": serialize(obj.data)}``
    - ``Enum`` member -> ``{"__Enum__": True, "name", "__module__", "__name__"}``
    - ``set`` -> ``{"__Set__": True, "as-list": [...]}``

    Anything else is returned unchanged.
    """
    if isinstance(obj, Serialize):
        encoded = {"__Serialized__": True, "data": serialize(obj.data)}
    elif isinstance(obj, Enum):
        enum_cls = type(obj)
        encoded = {
            "__Enum__": True,
            "name": obj.name,
            "__module__": obj.__module__,
            "__name__": enum_cls.__name__,
        }
    elif isinstance(obj, set):
        encoded = {"__Set__": True, "as-list": [*obj]}
    else:
        encoded = obj
    return encoded
def msgpack_dumps(x):
    """Serialize ``x`` into a single msgpack frame.

    Any error raised by msgpack is turned into ``NotImplementedError`` so
    that ``serialize`` moves on to the next family.
    """
    try:
        packed = msgpack.dumps(x, use_bin_type=True)
    except Exception:
        raise NotImplementedError()
    return {"serializer": "msgpack"}, [packed]
def msgpack_loads(header, frames):
    """Decode frames written by ``msgpack_dumps`` (arrays come back as tuples)."""
    payload = b"".join(frames)
    return msgpack.loads(payload, use_list=False, **msgpack_opts)
def serialization_error_loads(header, frames):
    """Raise the serialization failure reported by the sending side.

    Each frame is UTF-8 text. The frames are decoded, joined with newlines
    and raised as ``TypeError``.
    """
    text_parts = (codecs.decode(part, "utf8") for part in frames)
    raise TypeError("\n".join(text_parts))
# Serialization families by name: (dumps, loads, dumps_accepts_context).
families = {}


def register_serialization_family(name, dumps, loads):
    """Register, or replace, the serialization family ``name``.

    ``dumps`` may be ``None`` for load-only families such as "error". In that
    case the third tuple element is ``None``; otherwise it records whether
    ``dumps`` accepts a ``context`` keyword.
    """
    accepts_context = dumps and has_keyword(dumps, "context")
    families[name] = (dumps, loads, accepts_context)


register_serialization_family("dask", dask_dumps, dask_loads)
register_serialization_family("pickle", pickle_dumps, pickle_loads)
register_serialization_family("msgpack", msgpack_dumps, msgpack_loads)
register_serialization_family("error", None, serialization_error_loads)
def check_dask_serializable(x):
    """Return True if ``x`` has a ``dask_serialize`` handler.

    For a non-empty ``list``/``set``/``tuple`` the decision is made on its
    first element, and for a non-empty ``dict`` on its first value,
    recursively. Only those exact container types are descended into.
    """
    kind = type(x)
    if kind in (list, set, tuple) and len(x):
        return check_dask_serializable(next(iter(x)))
    if kind is dict and len(x):
        return check_dask_serializable(next(iter(x.values())))
    try:
        dask_serialize.dispatch(kind)
    except TypeError:
        return False
    return True
def serialize( # type: ignore[no-untyped-def]
x: object,
serializers=None,
on_error: Literal["message" | "raise"] = "message",
context=None,
iterate_collection: bool | None = None,
) -> tuple[dict[str, Any], list[bytes | memoryview]]:
r"""
Convert object to a header and list of bytestrings
This takes in an arbitrary Python object and returns a msgpack serializable
header and a list of bytes or memoryview objects.
The serialization protocols to use are configurable: a list of names
define the set of serializers to use, in order. These names are keys in
the ``serializer_registry`` dict (e.g., 'pickle', 'msgpack'), which maps
to the de/serialize functions. The name 'dask' is special, and will use the
per-class serialization methods. ``None`` gives the default list
``['dask', 'pickle']``.
Notes on the ``iterate_collection`` argument (only relevant when
``x`` is a collection):
- ``iterate_collection=True``: Serialize collection elements separately.
- ``iterate_collection=False``: Serialize collection elements together.
- ``iterate_collection=None`` (default): Infer the best setting.
Examples
--------
>>> serialize(1)
({}, [b'\x80\x04\x95\x03\x00\x00\x00\x00\x00\x00\x00K\x01.'])
>>> serialize(b'123') # some special types get custom treatment
({'type': 'builtins.bytes'}, [b'123'])
>>> deserialize(*serialize(1))
1
Returns
-------
header: dictionary containing any msgpack-serializable metadata
frames: list of bytes or memoryviews, commonly of length one
See Also
--------
deserialize : Convert header and frames back to object
to_serialize : Mark that data in a message should be serialized
register_serialization : Register custom serialization functions
"""
if serializers is None:
serializers = ("dask", "pickle") # TODO: get from configuration
# Handle obects that are marked as `Serialize`, or that are
# already `Serialized` objects (don't want to serialize them twice)
if isinstance(x, Serialized):
return x.header, x.frames
if isinstance(x, Serialize):
return serialize(
x.data,
serializers=serializers,
on_error=on_error,
context=context,
iterate_collection=True,
)
# Note: don't use isinstance(), as it would match subclasses
# (e.g. namedtuple, defaultdict) which however would revert to the base class on a
# round-trip through msgpack
if iterate_collection is None and type(x) in (list, set, tuple, dict):
if type(x) is list and "msgpack" in serializers:
# Note: "msgpack" will always convert lists to tuples
# (see GitHub #3716), so we should iterate
# through the list if "msgpack" comes before "pickle"
# in the list of serializers.
iterate_collection = ("pickle" not in serializers) or (
serializers.index("pickle") > serializers.index("msgpack")
)
if not iterate_collection:
# Check for "dask"-serializable data in dict/list/set
iterate_collection = check_dask_serializable(x)
# Determine whether keys are safe to be serialized with msgpack
if type(x) is dict and iterate_collection:
try:
msgpack.dumps(list(x.keys()))
except Exception:
dict_safe = False
else:
dict_safe = True
if (
type(x) in (list, set, tuple)
and iterate_collection
or type(x) is dict
and iterate_collection
and dict_safe
):
if isinstance(x, dict):
headers_frames = []
for k, v in x.items():
_header, _frames = serialize(
v, serializers=serializers, on_error=on_error, context=context
)
_header["key"] = k
headers_frames.append((_header, _frames))
else:
assert isinstance(x, (list, set, tuple))
headers_frames = [
serialize(
obj, serializers=serializers, on_error=on_error, context=context
)
for obj in x
]
frames = []
lengths = []
compressions: list[str | None] = []
for _header, _frames in headers_frames:
frames.extend(_frames)
length = len(_frames)
lengths.append(length)
compressions.extend(_header.get("compression") or [None] * len(_frames))
headers = {
"sub-headers": [obj[0] for obj in headers_frames],
"is-collection": True,
"frame-lengths": lengths,
"type-serialized": type(x).__name__,
}
if any(compression is not None for compression in compressions):
headers["compression"] = compressions
return headers, frames
tb = ""
for name in serializers:
dumps, _, wants_context = families[name]
try:
header, frames = dumps(x, context=context) if wants_context else dumps(x)
header["serializer"] = name
return header, frames
except NotImplementedError:
continue
except Exception:
tb = traceback.format_exc()
break
msg = f"Could not serialize object of type {type(x).__name__}"
if on_error == "message":
txt_frames = [msg]
if tb:
txt_frames.append(tb[:100000])
frames = [frame.encode() for frame in txt_frames]
return {"serializer": "error"}, frames
elif on_error == "raise":
raise TypeError(msg, str(x)[:10000])
else: # pragma: nocover
raise ValueError(f"{on_error=}; expected 'message' or 'raise'")
def deserialize(header, frames, deserializers=None):
    """
    Convert serialized header and list of bytestrings back to a Python object

    Parameters
    ----------
    header : dict
        Header produced by ``serialize``. It is not modified, so the same
        header can be deserialized again.
    frames : list of bytes
    deserializers : collection of str | None
        Optional allow-list of serialization family names (e.g.
        ``("dask", "pickle")``). Only membership is checked; the loader
        always comes from the registered families. ``None`` allows all.

    Raises
    ------
    TypeError
        If ``deserializers`` is given and does not contain
        ``header["serializer"]``.

    See Also
    --------
    serialize
    """
    if "is-collection" in header:
        headers = header["sub-headers"]
        lengths = header["frame-lengths"]
        cls = {"tuple": tuple, "list": list, "set": set, "dict": dict}[
            header["type-serialized"]
        ]

        start = 0
        if cls is dict:
            d = {}
            for _header, _length in zip(headers, lengths):
                # Work on a copy: popping "key" from the caller's sub-header
                # would make a second deserialization of the same header fail.
                _header = dict(_header)
                k = _header.pop("key")
                d[k] = deserialize(
                    _header,
                    frames[start : start + _length],
                    deserializers=deserializers,
                )
                start += _length
            return d
        else:
            lst = []
            for _header, _length in zip(headers, lengths):
                lst.append(
                    deserialize(
                        _header,
                        frames[start : start + _length],
                        deserializers=deserializers,
                    )
                )
                start += _length
            return cls(lst)

    name = header.get("serializer")
    if deserializers is not None and name not in deserializers:
        raise TypeError(
            "Data serialized with %s but only able to deserialize "
            "data with %s" % (name, str(list(deserializers)))
        )
    _, loads, _ = families[name]
    return loads(header, frames)
def serialize_and_split(
    x, serializers=None, on_error="message", context=None, size=None
):
    """Serialize and split compressible frames

    Drop-in replacement for ``serialize()``: serializes ``x``, then splits
    every frame without a compression entry with ``frame_split_size(frame,
    n=size)``. Frames that already have a compression entry are kept whole.
    Use ``merge_and_deserialize()`` to merge and deserialize the frames back.

    The header gains ``"split-num-sub-frames"``, ``"split-offsets"`` and a
    rewritten ``"compression"``. All three are tuples, to match msgpack's
    implicit list-to-tuple conversion.

    See Also
    --------
    serialize
    merge_and_deserialize
    """
    header, frames = serialize(x, serializers, on_error, context)
    compressions = header.get("compression") or [None] * len(frames)

    num_sub_frames = []
    offsets = []
    out_frames = []
    out_compression = []
    for frame, compression in zip(frames, compressions):
        offsets.append(len(out_frames))
        if compression is None:  # default behavior: shard the frame
            pieces = frame_split_size(frame, n=size)
            out_compression.extend([None] * len(pieces))
        else:
            pieces = [frame]
            out_compression.append(compression)
        num_sub_frames.append(len(pieces))
        out_frames.extend(pieces)
    assert len(out_compression) == len(out_frames)

    header["split-num-sub-frames"] = tuple(num_sub_frames)
    header["split-offsets"] = tuple(offsets)
    header["compression"] = tuple(out_compression)
    return header, out_frames
def merge_and_deserialize(header, frames, deserializers=None):
    """Merge and deserialize frames

    Drop-in replacement for ``deserialize()`` that first re-joins frames split
    by ``serialize_and_split()``, as described by the header's
    ``"split-num-sub-frames"`` and ``"split-offsets"``. Each group is merged
    with ``merge_memoryviews``. If that raises ``ValueError`` or ``TypeError``,
    the group is copied into one ``bytearray`` instead.

    See Also
    --------
    deserialize
    serialize_and_split
    """
    if "split-num-sub-frames" in header:
        merged_frames = []
        for count, start in zip(header["split-num-sub-frames"], header["split-offsets"]):
            pieces = frames[start : start + count]
            try:
                merged_frames.append(merge_memoryviews(pieces))
            except (ValueError, TypeError):
                merged_frames.append(bytearray().join(pieces))
        frames = merged_frames
    return deserialize(header, frames, deserializers=deserializers)
class Serialize:
    """Mark an object that should be serialized

    The wrapped object is available as ``.data``. Equality and hashing
    delegate to it.

    Examples
    --------
    >>> msg = {'op': 'update', 'data': to_serialize(123)}
    >>> msg  # doctest: +SKIP
    {'op': 'update', 'data': <Serialize: 123>}

    See also
    --------
    distributed.protocol.dumps
    """

    def __init__(self, data):
        self.data = data

    def __repr__(self):
        return "<Serialize: {}>".format(self.data)

    def __eq__(self, other):
        if not isinstance(other, Serialize):
            return False
        return other.data == self.data

    def __ne__(self, other):
        return not self == other

    def __hash__(self):
        return hash(self.data)


# Public alias used when building messages.
to_serialize = Serialize
class Serialized:
    """An object that is already serialized into header and frames

    Normal serialization operations pass these objects through. This is
    typically used within the scheduler which accepts messages that contain
    data without actually unpacking that data.

    Instances compare equal when both header and frames are equal. Defining
    ``__eq__`` without ``__hash__`` leaves them unhashable.
    """

    def __init__(self, header, frames):
        self.header = header
        self.frames = frames

    def __eq__(self, other):
        if not isinstance(other, Serialized):
            return False
        return other.header == self.header and other.frames == self.frames

    def __ne__(self, other):
        return not self == other
class ToPickle:
    """Mark an object that should be pickled

    Both the scheduler and workers will automatically unpickle this
    object on arrival.

    Notice, this requires that the scheduler is allowed to use pickle.
    If the configuration option "distributed.scheduler.pickle" is set
    to False, the scheduler will raise an exception instead.
    """

    def __init__(self, data):
        self.data = data

    def __repr__(self):
        return "<ToPickle: " + str(self.data) + ">"

    def __eq__(self, other):
        if not isinstance(other, type(self)):
            return False
        return other.data == self.data

    def __ne__(self, other):
        return not self == other

    def __hash__(self):
        return hash(self.data)
class Pickled:
    """An object that is already pickled into header and frames

    Normal pickled objects are unpickled by the scheduler.

    Instances compare equal when both header and frames are equal. Defining
    ``__eq__`` without ``__hash__`` leaves them unhashable.
    """

    def __init__(self, header, frames):
        self.header = header
        self.frames = frames

    def __eq__(self, other):
        if not isinstance(other, type(self)):
            return False
        return other.header == self.header and other.frames == self.frames

    def __ne__(self, other):
        return not self == other
def nested_deserialize(x):
    """
    Replace all Serialize and Serialized values nested in *x*
    with the original values. Returns a copy of *x*.

    Only exact ``dict`` and ``list`` containers are walked (and copied). A
    top-level value that is not one of those is returned unchanged.

    >>> msg = {'op': 'update', 'data': to_serialize(123)}
    >>> nested_deserialize(msg)
    {'op': 'update', 'data': 123}
    """

    def unwrap(value):
        # Containers are checked first, so Serialize/Serialized are only
        # consulted for non-container values.
        kind = type(value)
        if kind is dict or kind is list:
            return walk(value)
        if kind is Serialize:
            return value.data
        if kind is Serialized:
            return deserialize(value.header, value.frames)
        return value

    def walk(obj):
        if type(obj) is dict:
            return {key: unwrap(item) for key, item in obj.items()}
        if type(obj) is list:
            return [unwrap(item) for item in obj]
        return obj

    return walk(x)
def serialize_bytelist(x, **kwargs):
    """Serialize ``x`` into a flat list of frames ready to be sent.

    Layout: ``[prelude, msgpack(header), frame_0, frame_1, ...]``.
    - Each payload frame has been passed through ``maybe_compress``.
    - The per-frame compressions and frame count are stored in the header.
    - ``prelude`` is ``pack_frames_prelude`` of the remaining frames.

    ``kwargs`` go to ``serialize_and_split``.
    """
    header, frames = serialize_and_split(x, **kwargs)
    compression = []
    if frames:
        # Transpose [(c0, f0), (c1, f1), ...] into (c0, c1, ...), (f0, f1, ...).
        compression, frames = zip(*map(maybe_compress, frames))

    header["compression"] = compression
    header["count"] = len(frames)
    packed = [msgpack.dumps(header, use_bin_type=True)]
    packed.extend(frames)
    return [pack_frames_prelude(packed), *packed]
def serialize_bytes(x, **kwargs):
    """Serialize ``x`` into one contiguous ``bytes`` object.

    This concatenates the frames returned by ``serialize_bytelist``.
    """
    frames = serialize_bytelist(x, **kwargs)
    return b"".join(frames)
def deserialize_bytes(b):
    """Inverse of ``serialize_bytes``.

    The first unpacked frame is the msgpack header (an empty frame means an
    empty header). The remaining frames are decompressed according to the
    header, then merged and deserialized.
    """
    unpacked = unpack_frames(b)
    raw_header = unpacked[0]
    header = msgpack.loads(raw_header, raw=False, use_list=False) if raw_header else {}
    payload = decompress(header, unpacked[1:])
    return merge_and_deserialize(header, payload)
################################
# Class specific serialization #
################################
def register_serialization(cls, serialize, deserialize):
    """Register a new class for dask-custom serialization

    Parameters
    ----------
    cls : type
    serialize : callable(cls) -> Tuple[Dict, List[bytes]]
    deserialize : callable(header: Dict, frames: List[bytes]) -> cls

    Raises
    ------
    TypeError
        If ``cls`` is a string (lazy registration by name is not supported
        here; use ``dask_serialize.register_lazy``).

    Examples
    --------
    >>> class Human:
    ...     def __init__(self, name):
    ...         self.name = name

    >>> def serialize_human(human):
    ...     header = {}
    ...     frames = [human.name.encode()]
    ...     return header, frames

    >>> def deserialize_human(header, frames):
    ...     return Human(frames[0].decode())

    >>> register_serialization(Human, serialize_human, deserialize_human)
    >>> serialize_human(Human('Alice'))
    ({}, [b'Alice'])

    See Also
    --------
    serialize
    deserialize
    """
    if isinstance(cls, str):
        raise TypeError(
            "Strings are no longer accepted for type registration. "
            "Use dask_serialize.register_lazy instead"
        )
    for registry, func in ((dask_serialize, serialize), (dask_deserialize, deserialize)):
        registry.register(cls)(func)
def register_serialization_lazy(toplevel, func):
    """Removed API; always raises ``Exception``.

    This used to register a callback to run when the *toplevel* module was
    first imported. Serialization registration has since changed.
    """
    message = "Serialization registration has changed. See documentation"
    raise Exception(message)
@partial(normalize_token.register, Serialized)
def normalize_Serialized(o):
    """Normalize a ``Serialized`` for ``dask.base.tokenize``.

    Returns a list of the header followed by every frame. ``o.frames`` may be
    a tuple rather than a list, for instance when the header/frames pair came
    out of msgpack decoded with ``use_list=False``. ``[o.header] + o.frames``
    would raise ``TypeError`` in that case, so the frames are unpacked instead.
    """
    return [o.header, *o.frames]
# "dask" family handler for ``bytes``: sent as-is.
@dask_serialize.register(bytes)
def _serialize_bytes(obj):
    """Serialize ``bytes`` as an empty header plus the object as its only frame."""
    return {}, [obj]
# "dask" family handler for ``bytearray``: sent as-is.
@dask_serialize.register(bytearray)
def _serialize_bytearray(obj):
    """Serialize a ``bytearray`` as an empty header plus the object as its only frame."""
    return {}, [obj]
@dask_deserialize.register(bytes)
def _deserialize_bytes(header, frames):
    """Rebuild ``bytes``; a single ``bytes`` frame is returned without copying."""
    if len(frames) != 1 or not isinstance(frames[0], bytes):
        return b"".join(frames)
    return frames[0]
@dask_deserialize.register(bytearray)
def _deserialize_bytearray(header, frames):
    """Rebuild a ``bytearray``; a single ``bytearray`` frame is returned without copying."""
    if len(frames) != 1 or not isinstance(frames[0], bytearray):
        return bytearray().join(frames)
    return frames[0]
@dask_serialize.register(array)
def _serialize_array(obj):
    """Serialize an ``array.array`` as its typecode plus a zero-copy memoryview frame."""
    # NOTE(review): "writeable" is a single None flag; its consumer is not
    # visible in this block.
    return {"typecode": obj.typecode, "writeable": (None,)}, [memoryview(obj)]
@dask_deserialize.register(array)
def _deserialize_array(header, frames):
    """Rebuild an ``array.array`` of ``header["typecode"]`` from the frames' bytes.

    Zero frames give an empty array.
    """
    result = array(header["typecode"])
    if len(frames) == 1:
        result.frombytes(ensure_memoryview(frames[0]))
    elif frames:
        result.frombytes(b"".join(map(ensure_memoryview, frames)))
    return result
@dask_serialize.register(memoryview)
def _serialize_memoryview(obj):
    """Serialize a ``memoryview`` as its format and shape plus the view itself.

    Raises
    ------
    ValueError
        For object-format views, or empty views with more than one dimension.
    """
    if obj.format == "O":
        raise ValueError("Cannot serialize `memoryview` containing Python objects")
    # ``not obj`` is evaluated first on purpose (it calls len()).
    if not obj and obj.ndim > 1:
        raise ValueError("Cannot serialize empty non-1-D `memoryview`")
    header = dict(format=obj.format, shape=obj.shape)
    return header, [obj]
@dask_deserialize.register(memoryview)
def _deserialize_memoryview(header, frames):
    """Rebuild a ``memoryview`` with the recorded format and shape.

    A single frame is used without copying; several frames are joined first.
    """
    if len(frames) == 1:
        flat = ensure_memoryview(frames[0])
    else:
        flat = memoryview(b"".join(frames))

    if flat:
        result = flat.cast(header["format"], header["shape"])
    else:
        # Empty views: cast with the format only, then check the shape
        # (presumably cast() with a shape is unreliable for empty buffers).
        result = flat.cast(header["format"])
        assert result.shape == header["shape"]
    return result
#########################
# Descend into __dict__ #
#########################
def _is_msgpack_serializable(v):
typ = type(v)
return (
v is None
or typ is str
or typ is bool
or typ is int
or typ is float
or isinstance(v, dict)
and all(map(_is_msgpack_serializable, v.values()))
and all(typ is str for x in v.keys())
or isinstance(v, (list, tuple))
and all(map(_is_msgpack_serializable, v))
)
class ObjectDictSerializer:
    """(De)serialize an object by walking its ``__dict__`` (or a dict itself).

    Header layout produced by ``serialize``:
    - ``"simple"``: values accepted by ``_is_msgpack_serializable``, stored
      directly.
    - ``"complex"``: every other value. Nested dicts are serialized
      recursively with this serializer and their sub-header is wrapped as
      ``{"nested-dict": ...}``. Anything else goes through ``serialize`` with
      ``(self.serializer, "pickle")``. Each entry records its sub-header and
      the ``start``/``stop`` slice of the concatenated frame list.
    - ``"type-serialized"``: the pickled type of the object.

    On deserialize, dict types are rebuilt as a plain ``dict``. Any other type
    is created with ``object.__new__`` and its ``__dict__`` is filled in.
    """

    def __init__(self, serializer):
        # Family name tried before "pickle" for non-simple values.
        self.serializer = serializer

    def serialize(self, est):
        simple = {}
        complex_entries = {}
        header = {
            "serializer": self.serializer,
            "type-serialized": pickle.dumps(type(est)),
            "simple": simple,
            "complex": complex_entries,
        }
        frames = []
        attrs = est if isinstance(est, dict) else est.__dict__

        for key, value in attrs.items():
            if _is_msgpack_serializable(value):
                simple[key] = value
                continue
            if isinstance(value, dict):
                sub_header, sub_frames = self.serialize(value)
                sub_header = {"nested-dict": sub_header}
            else:
                # Module-level ``serialize`` (class scope is not visible here).
                sub_header, sub_frames = serialize(
                    value, serializers=(self.serializer, "pickle")
                )
            first = len(frames)
            complex_entries[key] = {
                "header": sub_header,
                "start": first,
                "stop": first + len(sub_frames),
            }
            frames += sub_frames
        return header, frames

    def deserialize(self, header, frames):
        cls = pickle.loads(header["type-serialized"])
        if issubclass(cls, dict):
            obj = {}
            attrs = obj
        else:
            obj = object.__new__(cls)
            attrs = obj.__dict__

        attrs.update(header["simple"])
        for key, entry in header["complex"].items():
            sub_header = entry["header"]
            sub_frames = frames[entry["start"] : entry["stop"]]
            nested = sub_header.get("nested-dict")
            if nested:
                attrs[key] = self.deserialize(nested, sub_frames)
            else:
                attrs[key] = deserialize(sub_header, sub_frames)
        return obj
# Default __dict__-walking serializer for the "dask" family. Its deserializer
# is registered for ``dict`` in ``dask_deserialize``.
dask_object_with_dict_serializer = ObjectDictSerializer("dask")
dask_deserialize.register(dict, dask_object_with_dict_serializer.deserialize)
def register_generic(
    cls,
    serializer_name="dask",
    serialize_func=dask_serialize,
    deserialize_func=dask_deserialize,
):
    """Register (de)serialize to traverse through __dict__

    Normally when registering new classes for Dask's custom serialization you
    need to manage headers and frames, which can be tedious. If all you want
    to do is traverse through your object and apply serialize to all of your
    object's attributes then this function may provide an easier path.

    This registers a class for the custom Dask serialization family. It
    serializes it by traversing through its __dict__ of attributes and applying
    ``serialize`` and ``deserialize`` recursively. It collects a set of frames
    and keeps small attributes in the header. Deserialization reverses this
    process.

    This is a good idea if the following hold:

    1.  Most of the bytes of your object are composed of data types that Dask's
        custom serialization already handles well, like Numpy arrays.
    2.  Your object doesn't require any special constructor logic, other than
        object.__new__(cls)

    Examples
    --------
    >>> import sklearn.base
    >>> from distributed.protocol import register_generic
    >>> register_generic(sklearn.base.BaseEstimator)

    See Also
    --------
    dask_serialize
    dask_deserialize
    """
    walker = ObjectDictSerializer(serializer_name)
    serialize_func.register(cls)(walker.serialize)
    deserialize_func.register(cls)(walker.deserialize)
|