"""
Record known compressors
Includes utilities for determining whether or not to compress
"""
from contextlib import suppress
from functools import partial
import logging
import random
import dask
from tlz import identity
try:
import blosc
n = blosc.set_nthreads(2)
if hasattr("blosc", "releasegil"):
blosc.set_releasegil(True)
except ImportError:
blosc = False
from ..utils import ensure_bytes
compressions = {None: {"compress": identity, "decompress": identity}}
compressions[False] = compressions[None] # alias
default_compression = None
logger = logging.getLogger(__name__)
with suppress(ImportError):
import zlib
compressions["zlib"] = {"compress": zlib.compress, "decompress": zlib.decompress}
with suppress(ImportError):
import snappy
def _fixed_snappy_decompress(data):
# snappy.decompress() doesn't accept memoryviews
if isinstance(data, (memoryview, bytearray)):
data = bytes(data)
return snappy.decompress(data)
compressions["snappy"] = {
"compress": snappy.compress,
"decompress": _fixed_snappy_decompress,
}
default_compression = "snappy"
with suppress(ImportError):
import lz4
try:
# try using the new lz4 API
import lz4.block
lz4_compress = lz4.block.compress
lz4_decompress = lz4.block.decompress
except ImportError:
# fall back to old one
lz4_compress = lz4.LZ4_compress
lz4_decompress = lz4.LZ4_uncompress
# helper to bypass missing memoryview support in current lz4
# (fixed in later versions)
def _fixed_lz4_compress(data):
try:
return lz4_compress(data)
except TypeError:
if isinstance(data, (memoryview, bytearray)):
return lz4_compress(bytes(data))
else:
raise
def _fixed_lz4_decompress(data):
try:
return lz4_decompress(data)
except (ValueError, TypeError):
if isinstance(data, (memoryview, bytearray)):
return lz4_decompress(bytes(data))
else:
raise
compressions["lz4"] = {
"compress": _fixed_lz4_compress,
"decompress": _fixed_lz4_decompress,
}
default_compression = "lz4"
with suppress(ImportError):
import zstandard
zstd_compressor = zstandard.ZstdCompressor(
level=dask.config.get("distributed.comm.zstd.level"),
threads=dask.config.get("distributed.comm.zstd.threads"),
)
zstd_decompressor = zstandard.ZstdDecompressor()
def zstd_compress(data):
return zstd_compressor.compress(data)
def zstd_decompress(data):
return zstd_decompressor.decompress(data)
compressions["zstd"] = {"compress": zstd_compress, "decompress": zstd_decompress}
with suppress(ImportError):
import blosc
compressions["blosc"] = {
"compress": partial(blosc.compress, clevel=5, cname="lz4"),
"decompress": blosc.decompress,
}
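
# Illustrative sketch (not part of the registry above): additional codecs can
# be wired in by following the same pattern.  For example, the stdlib's bz2
# module could be registered like this -- a hypothetical addition, shown only
# to document the registry's shape:
#
#     import bz2
#     compressions["bz2"] = {"compress": bz2.compress, "decompress": bz2.decompress}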
def get_default_compression():
default = dask.config.get("distributed.comm.compression")
if default != "auto":
if default in compressions:
return default
else:
raise ValueError(
"Default compression '%s' not found.\n"
"Choices include auto, %s"
% (default, ", ".join(sorted(map(str, compressions))))
)
else:
return default_compression
get_default_compression()
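# Illustrative: the default is driven by dask's configuration system, so a
# deployment can pin a specific codec (assuming it is installed), e.g.:
#
#     import dask
#     dask.config.set({"distributed.comm.compression": "zlib"})
#     assert get_default_compression() == "zlib"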
def byte_sample(b, size, n):
"""Sample a bytestring from many locations
Parameters
----------
b: bytes or memoryview
size: int
size of each sample to collect
n: int
number of samples to collect
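
    Examples
    --------
    >>> sample = byte_sample(b"abcdefgh" * 1000, 10, 3)
    >>> len(sample) <= 30  # at most n * size bytes; fewer if samples overlap
    True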
"""
    # Sort the sampled offsets so the overlap trimming below is valid
    starts = sorted(random.randint(0, len(b) - size) for _ in range(n))
ends = []
for i, start in enumerate(starts[:-1]):
ends.append(min(start + size, starts[i + 1]))
ends.append(starts[-1] + size)
parts = [b[start:end] for start, end in zip(starts, ends)]
return b"".join(map(ensure_bytes, parts))
def maybe_compress(
payload,
min_size=1e4,
sample_size=1e4,
nsamples=5,
compression=dask.config.get("distributed.comm.compression"),
):
"""
Maybe compress payload
1. We don't compress small messages
2. We sample the payload in a few spots, compress that, and if it doesn't
do any good we return the original
3. We then compress the full original, it it doesn't compress well then we
return the original
4. We return the compressed result
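
    Examples
    --------
    >>> maybe_compress(b"x")  # too small to be worth compressing
    (None, b'x')
    >>> compression, compressed = maybe_compress(b"0" * 100000)
    >>> compression is None or len(compressed) < 100000
    True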
"""
if compression == "auto":
compression = default_compression
if not compression:
return None, payload
if len(payload) < min_size:
return None, payload
if len(payload) > 2 ** 31: # Too large, compression libraries often fail
return None, payload
min_size = int(min_size)
sample_size = int(sample_size)
compress = compressions[compression]["compress"]
# Compress a sample, return original if not very compressed
sample = byte_sample(payload, sample_size, nsamples)
if len(compress(sample)) > 0.9 * len(sample): # sample not very compressible
return None, payload
    if type(payload) is memoryview:
        # len() counts elements along the first dimension only;
        # nbytes is exact for any shape or itemsize
        nbytes = payload.nbytes
    else:
        nbytes = len(payload)
if default_compression and blosc and type(payload) is memoryview:
# Blosc does itemsize-aware shuffling, resulting in better compression
compressed = blosc.compress(
payload, typesize=payload.itemsize, cname="lz4", clevel=5
)
compression = "blosc"
else:
compressed = compress(ensure_bytes(payload))
if len(compressed) > 0.9 * nbytes: # full data not very compressible
return None, payload
else:
return compression, compressed
def decompress(header, frames):
""" Decompress frames according to information in the header """
return [
compressions[c]["decompress"](frame)
for c, frame in zip(header["compression"], frames)
]
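
# Illustrative round trip (assuming at least one codec or the identity path):
# ``maybe_compress`` returns the ``(compression, payload)`` pair whose first
# element is what ``decompress`` expects to find in ``header["compression"]``.
#
#     compression, compressed = maybe_compress(b"0" * 100000)
#     header = {"compression": [compression]}
#     [restored] = decompress(header, [compressed])
#     assert bytes(restored) == b"0" * 100000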