from __future__ import annotations
import math
import numpy as np
from distributed.protocol import pickle
from distributed.protocol.serialize import dask_deserialize, dask_serialize
from distributed.utils import log_errors


def itemsize(dt):
    """Itemsize of dtype

    Try to return the itemsize of the base element, return 8 as a fallback
    """
    result = dt.base.itemsize
    if result > 255:
        result = 8
    return result
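

# itemsize(...) illustration of the base-element lookup and the fallback:
#   itemsize(np.dtype("i8"))      -> 8   (plain scalar dtype)
#   itemsize(np.dtype("(3,)f4"))  -> 4   (itemsize of the base element)
#   itemsize(np.dtype("V300"))    -> 8   (300 > 255, so the fallback kicks in)

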
@dask_serialize.register(np.ndarray)
def serialize_numpy_ndarray(x, context=None):
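    """Build a (header, frames) pair for a numpy array.

    The header is a small dict (kept msgpack-serializable) describing dtype,
    shape and strides; the frames hold the raw bytes. Object arrays, and
    dtypes numpy flags with LIST_PICKLE, cannot be represented as flat
    buffers, so they fall back to pickle; with pickle protocol 5 the
    ``buffer_callback`` below collects out-of-band buffers as extra frames,
    which the deserializer feeds back via ``pickle.loads(..., buffers=...)``.
    """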
    if x.dtype.hasobject or (x.dtype.flags & np.core.multiarray.LIST_PICKLE):
        header = {"pickle": True}
        frames = [None]

        def buffer_callback(f):
            frames.append(memoryview(f))

        frames[0] = pickle.dumps(
            x,
            buffer_callback=buffer_callback,
            protocol=(context or {}).get("pickle-protocol", None),
        )
        return header, frames

    # We cannot blindly pickle the dtype as some may fail pickling,
    # so we have a mixture of strategies.
    if x.dtype.kind == "V":
        # Preserving all the information works best when pickling
        try:
            # Only use stdlib pickle as cloudpickle is slow when failing
            # (microseconds instead of nanoseconds)
            dt = (
                1,
                pickle.pickle.dumps(
                    x.dtype, protocol=(context or {}).get("pickle-protocol", None)
                ),
            )
            pickle.loads(dt[1])  # does it unpickle fine?
        except Exception:
            # dtype fails pickling => fall back on the descr if reasonable.
            if x.dtype.type is not np.void or x.dtype.alignment != 1:
                raise
            else:
                dt = (0, x.dtype.descr)
    else:
        dt = (0, x.dtype.str)
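
    # At this point dt is an (is_custom, payload) pair: e.g. (0, "<i8") for
    # a plain little-endian int64, (1, <pickled dtype bytes>) for a
    # structured dtype, or (0, descr) when such a dtype refuses to pickle.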

    # Only serialize broadcastable data for arrays with zero strided axes
    broadcast_to = None
    if 0 in x.strides:
        broadcast_to = x.shape
        strides = x.strides
        writeable = x.flags.writeable
        x = x[tuple(slice(None) if s != 0 else slice(1) for s in strides)]
        if not x.flags.c_contiguous and not x.flags.f_contiguous:
            # Broadcasting can only be done with contiguous arrays
            x = np.ascontiguousarray(x)
            x = np.lib.stride_tricks.as_strided(
                x,
                strides=[j if i != 0 else i for i, j in zip(strides, x.strides)],
                writeable=writeable,
            )
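
    # For example, np.broadcast_to(np.arange(3), (1000, 3)) has strides
    # (0, 8) on a 64-bit platform: the zero-stride axis is sliced down to
    # length 1 above, only that single row is shipped, and the full
    # (1000, 3) shape travels in header["broadcast_to"].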

    if not x.shape:
        # 0d array
        strides = x.strides
        data = x.ravel()
    elif x.flags.c_contiguous or x.flags.f_contiguous:
        # Avoid a copy and respect order when unserializing
        strides = x.strides
        data = x.ravel(order="K")
    else:
        x = np.ascontiguousarray(x)
        strides = x.strides
        data = x.ravel()
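
    # math.gcd(x.dtype.itemsize, 8) below picks the widest of 1, 2, 4 or 8
    # bytes that evenly divides the itemsize, so the unsigned-int view
    # ("u1"/"u2"/"u4"/"u8") reinterprets the buffer without changing its
    # length; datetime/timedelta and structured dtypes cannot be exported
    # as plain buffers directly.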
    if data.dtype.fields or data.dtype.itemsize > 8:
        data = data.view("u%d" % math.gcd(x.dtype.itemsize, 8))

    try:
        data = data.data
    except ValueError:
        # "ValueError: cannot include dtype 'M' in a buffer"
        data = data.view("u%d" % math.gcd(x.dtype.itemsize, 8)).data

    header = {
        "dtype": dt,
        "shape": x.shape,
        "strides": strides,
        "writeable": [x.flags.writeable],
    }

    if broadcast_to is not None:
        header["broadcast_to"] = broadcast_to

    frames = [data]
    return header, frames
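

# A sketch of the wire format serialize_numpy_ndarray produces for a plain
# array (values assume a little-endian platform with 8-byte ints):
#
#     >>> header, frames = serialize_numpy_ndarray(np.arange(3))
#     >>> header["dtype"], header["shape"], header["strides"]
#     ((0, '<i8'), (3,), (8,))
#     >>> len(frames), frames[0].nbytes
#     (1, 24)

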
@dask_deserialize.register(np.ndarray)
@log_errors
def deserialize_numpy_ndarray(header, frames):
if header.get("pickle"):
return pickle.loads(frames[0], buffers=frames[1:])
(frame,) = frames
(writeable,) = header["writeable"]
is_custom, dt = header["dtype"]
if is_custom:
dt = pickle.loads(dt)
else:
dt = np.dtype(dt)
if header.get("broadcast_to"):
shape = header["broadcast_to"]
else:
shape = header["shape"]
x = np.ndarray(shape, dtype=dt, buffer=frame, strides=header["strides"])
if not writeable:
x.flags.writeable = False
else:
x = np.require(x, requirements=["W"])
return x
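

# Round-trip sketch for the two functions above: the array is rebuilt as a
# view over the received frame (np.require only copies when a writeable
# array is requested on top of a read-only buffer):
#
#     >>> y = deserialize_numpy_ndarray(*serialize_numpy_ndarray(np.arange(3)))
#     >>> y
#     array([0, 1, 2])

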
@dask_serialize.register(np.ma.core.MaskedConstant)
def serialize_numpy_ma_masked(x):
    return {}, []

@dask_deserialize.register(np.ma.core.MaskedConstant)
def deserialize_numpy_ma_masked(header, frames):
    return np.ma.masked
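

# np.ma.masked is a process-wide singleton, so the two handlers above need
# no payload: an empty header and zero frames deserialize back to the very
# same object.

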
@dask_serialize.register(np.ma.core.MaskedArray)
def serialize_numpy_maskedarray(x, context=None):
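    """Serialize a masked array as its data array followed by its mask.

    Frames are laid out as [*data_frames, *mask_frames]; header["nframes"]
    records where the data frames end so the deserializer can split the
    list again.
    """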
    data_header, frames = serialize_numpy_ndarray(x.data)
    header = {"data-header": data_header, "nframes": len(frames)}

    # Serialize mask if present
    if x.mask is not np.ma.nomask:
        mask_header, mask_frames = serialize_numpy_ndarray(x.mask)
        header["mask-header"] = mask_header
        frames += mask_frames

    # Only a few dtypes have python equivalents msgpack can serialize
    if isinstance(x.fill_value, (np.integer, np.floating, np.bool_)):
        serialized_fill_value = (False, x.fill_value.item())
    else:
        serialized_fill_value = (
            True,
            pickle.dumps(
                x.fill_value, protocol=(context or {}).get("pickle-protocol", None)
            ),
        )
    header["fill-value"] = serialized_fill_value

    return header, frames

@dask_deserialize.register(np.ma.core.MaskedArray)
def deserialize_numpy_maskedarray(header, frames):
    data_header = header["data-header"]
    data_frames = frames[: header["nframes"]]
    data = deserialize_numpy_ndarray(data_header, data_frames)

    if "mask-header" in header:
        mask_header = header["mask-header"]
        mask_frames = frames[header["nframes"] :]
        mask = deserialize_numpy_ndarray(mask_header, mask_frames)
    else:
        mask = np.ma.nomask

    pickled_fv, fill_value = header["fill-value"]
    if pickled_fv:
        fill_value = pickle.loads(fill_value)

    return np.ma.masked_array(data, mask=mask, fill_value=fill_value)
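

# Masked round-trip sketch (masked entries come back as None via tolist):
#
#     >>> m = np.ma.masked_array([1, 2, 3], mask=[False, True, False])
#     >>> header, frames = serialize_numpy_maskedarray(m)
#     >>> deserialize_numpy_maskedarray(header, frames).tolist()
#     [1, None, 3]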