File: compression.py

Package: dask.distributed 2024.12.1+ds-1
"""
Record known compressors

Includes utilities for determining whether or not to compress
"""

from __future__ import annotations

import zlib
from collections.abc import Callable, Iterable
from contextlib import suppress
from functools import partial
from random import randint
from typing import TYPE_CHECKING, Any, Literal, NamedTuple

from packaging.version import parse as parse_version
from tlz import identity

import dask

from distributed.metrics import context_meter
from distributed.utils import ensure_memoryview, nbytes

if TYPE_CHECKING:
    # TODO import from typing (requires Python >=3.10)
    from typing_extensions import TypeAlias

# TODO remove quotes (requires Python >=3.10)
AnyBytes: TypeAlias = "bytes | bytearray | memoryview"


class Compression(NamedTuple):
    name: None | str
    compress: Callable[[AnyBytes], AnyBytes]
    decompress: Callable[[AnyBytes], AnyBytes]


compressions: dict[str | None | Literal[False], Compression] = {
    None: Compression(None, identity, identity),
    False: Compression(None, identity, identity),  # alias
    "auto": Compression(None, identity, identity),
    "zlib": Compression("zlib", zlib.compress, zlib.decompress),
}
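
# Illustrative round-trip through a registered codec (not part of the module;
# shown only to document the table's shape: each entry pairs a wire name with
# a compress/decompress callable):
#
#   >>> comp = compressions["zlib"]
#   >>> comp.decompress(comp.compress(b"hello" * 100)) == b"hello" * 100
#   True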


with suppress(ImportError):
    import snappy

    # In python-snappy 0.5.3, support for the Python Buffer Protocol was added.
    # This is needed to handle other objects (like `memoryview`s) without
    # copying to `bytes` first.
    #
    # Note: `snappy.__version__` doesn't exist in a release yet.
    #       So do a little test that will fail if snappy is not 0.5.3 or later.
    try:
        snappy.compress(memoryview(b""))
    except TypeError:
        raise ImportError("Need snappy >= 0.5.3")

    compressions["snappy"] = Compression("snappy", snappy.compress, snappy.decompress)
    compressions["auto"] = compressions["snappy"]

with suppress(ImportError):
    import lz4

    # Required to use `lz4.block` APIs and Python Buffer Protocol support.
    if parse_version(lz4.__version__) < parse_version("0.23.1"):
        raise ImportError("Need lz4 >= 0.23.1")

    import lz4.block

    compressions["lz4"] = Compression(
        "lz4",
        lz4.block.compress,
        # Avoid expensive deep copies when deserializing writeable numpy arrays
        # See distributed.protocol.numpy.deserialize_numpy_ndarray
        # Note that this is only useful for buffers smaller than distributed.comm.shard;
        # larger ones are deep-copied between decompression and serialization anyway in
        # order to merge them.
        partial(lz4.block.decompress, return_bytearray=True),
    )
    compressions["auto"] = compressions["lz4"]


with suppress(ImportError):
    import zstandard

    # Required for Python Buffer Protocol support.
    if parse_version(zstandard.__version__) < parse_version("0.9.0"):
        raise ImportError("Need zstandard >= 0.9.0")

    def zstd_compress(data):
        zstd_compressor = zstandard.ZstdCompressor(
            level=dask.config.get("distributed.comm.zstd.level"),
            threads=dask.config.get("distributed.comm.zstd.threads"),
        )
        return zstd_compressor.compress(data)

    def zstd_decompress(data):
        zstd_decompressor = zstandard.ZstdDecompressor()
        return zstd_decompressor.decompress(data)

    compressions["zstd"] = Compression("zstd", zstd_compress, zstd_decompress)


def get_compression_settings(key: str) -> str | None:
    """Fetch and validate compression settings, with a nice error message in case of
    failure. This also resolves 'auto', which may differ between different hosts of the
    same cluster.
    """
    name = dask.config.get(key)
    try:
        return compressions[name].name
    except KeyError:
        valid = ",".join(repr(n) for n in compressions)
        raise ValueError(
            f"Invalid compression setting {key}={name}. Valid options are {valid}."
        )
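
# Illustrative session (output depends on which codecs imported above; with
# lz4 installed, "auto" resolves to it):
#
#   >>> get_compression_settings("distributed.comm.compression")  # doctest: +SKIP
#   'lz4'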


def byte_sample(b: memoryview, size: int, n: int) -> memoryview:
    """Sample a bytestring from many locations

    Parameters
    ----------
    b : full memoryview
    size : int
        target size of each sample to collect
        (may be smaller if samples collide)
    n : int
        number of samples to collect
    """
    assert size >= 0 and n >= 0
    if size == 0 or n == 0:
        return memoryview(b"")

    parts = []
    # Note: this requires b.nbytes >= size, else randint's upper bound is negative
    max_start = b.nbytes - size
    start = randint(0, max_start)
    for _ in range(n - 1):
        next_start = randint(0, max_start)
        # Truncate this sample at the next sample's start so that colliding
        # samples aren't counted twice; the slice may even be empty when the
        # next start falls before this one
        end = min(start + size, next_start)
        parts.append(b[start:end])
        start = next_start
    parts.append(b[start : start + size])

    if n == 1:
        return parts[0]
    else:
        return memoryview(b"".join(parts))


@context_meter.meter("compress")
def maybe_compress(
    payload: bytes | bytearray | memoryview,
    *,
    min_size: int = 10_000,
    sample_size: int = 10_000,
    nsamples: int = 5,
    min_ratio: float = 0.7,
    compression: str | None | Literal[False] = "auto",
) -> tuple[str | None, AnyBytes]:
    """Maybe compress payload

    1. Don't compress the payload if it is smaller than min_size
    2. Sample the payload in <nsamples> spots, compress those samples, and if
       they don't compress to at most <min_ratio> of their original size,
       return the original
    3. Then compress the full payload; if it doesn't compress to at most
       <min_ratio> of its original size, return the original
    4. Otherwise, return the compressed output

    Returns
    -------
    - Name of compression algorithm used
    - Either compressed or original payload
    """
    comp = compressions[compression]
    if not comp.name:
        return None, payload
    if not (min_size <= nbytes(payload) <= 2**31):
        # Either too small to bother
        # or too large (compression libraries often fail)
        return None, payload

    # Take a view of payload for efficient usage
    mv = ensure_memoryview(payload)

    # Try compressing a sample to see if it compresses well
    sample = byte_sample(mv, sample_size, nsamples)
    if len(comp.compress(sample)) <= min_ratio * sample.nbytes:
        # Try compressing the real thing and check how compressed it is
        compressed = comp.compress(mv)
        if len(compressed) <= min_ratio * mv.nbytes:
            return comp.name, compressed
    # Skip compression as the sample or the data didn't compress well
    return None, payload


@context_meter.meter("decompress")
def decompress(header: dict[str, Any], frames: Iterable[AnyBytes]) -> list[AnyBytes]:
    """Decompress frames according to information in the header"""
    return [
        compressions[name].decompress(frame)
        for name, frame in zip(header["compression"], frames)
    ]
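
# Illustrative pairing with maybe_compress (sketch): the per-frame name it
# returns is what a header's "compression" list is expected to hold.
#
#   >>> name, frame = maybe_compress(b"0" * 100_000)             # doctest: +SKIP
#   >>> bytes(decompress({"compression": [name]}, [frame])[0][:5])
#   b'00000'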