File: _disk.py

package info (click to toggle)
dask.distributed 2024.12.1%2Bds-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 12,588 kB
  • sloc: python: 96,954; javascript: 1,549; sh: 390; makefile: 220
file content (228 lines) | stat: -rw-r--r-- 8,248 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
from __future__ import annotations

import contextlib
import errno
import pathlib
import shutil
import threading
from collections.abc import Callable, Generator, Iterable
from contextlib import contextmanager
from typing import Any

from toolz import concat

from distributed.metrics import context_meter, thread_time
from distributed.shuffle._buffer import ShardsBuffer
from distributed.shuffle._exceptions import DataUnavailable, P2POutOfDiskError
from distributed.shuffle._limiter import ResourceLimiter
from distributed.shuffle._pickle import pickle_bytelist
from distributed.utils import Deadline, empty_context, log_errors, nbytes


class ReadWriteLock:
    """Lock that admits many concurrent readers or a single writer.

    Any number of read locks may be held at the same time.  A writer first
    announces itself by setting ``_write_pending``, which keeps new readers and
    other writers from getting in, and then waits for the read locks that are
    already held to be released before it takes the lock.

    All state is protected by a single :class:`threading.Condition`.
    """

    _condition: threading.Condition
    # Number of read locks currently held
    _n_reads: int
    # Set from the moment a writer claims the lock until it releases it (or
    # gives up on a timeout); blocks new readers and other writers
    _write_pending: bool
    # Set only while a writer actually holds the lock
    _write_active: bool

    def __init__(self) -> None:
        self._condition = threading.Condition(threading.Lock())
        self._n_reads = 0
        self._write_pending = False
        self._write_active = False

    def acquire_write(self, timeout: float = -1) -> bool:
        """Acquire the lock for exclusive (write) access.

        Parameters
        ----------
        timeout : float
            Seconds to wait.  Both waiting phases below draw on the same
            deadline.  A negative value (the default) is handed to
            ``Deadline.after`` as ``None``, i.e. no time limit is requested.

        Returns
        -------
        bool
            ``True`` if the write lock was acquired, ``False`` on timeout.
        """
        deadline = Deadline.after(timeout if timeout >= 0 else None)
        with self._condition:
            # Phase 1: wait for any other writer to finish or give up
            if not self._condition.wait_for(
                lambda: not self._write_pending, timeout=deadline.remaining
            ):
                return False

            # Phase 2: claim the lock so no new readers get in, then wait for
            # the readers that already hold it to drain
            self._write_pending = True
            if not self._condition.wait_for(
                lambda: self._n_reads == 0, timeout=deadline.remaining
            ):
                # Give up the claim and wake everybody we have been blocking
                self._write_pending = False
                self._condition.notify_all()
                return False

            self._write_active = True
            return True

    def release_write(self) -> None:
        """Release the write lock and wake up all waiting threads.

        Raises
        ------
        RuntimeError
            If no write lock is currently held.
        """
        with self._condition:
            if not self._write_active:
                raise RuntimeError("Tried releasing unlocked write lock")
            self._write_pending = False
            self._write_active = False
            self._condition.notify_all()

    def acquire_read(self, timeout: float = -1) -> bool:
        """Acquire the lock for shared (read) access.

        Parameters
        ----------
        timeout : float
            Seconds to wait for a pending writer to finish.  A negative value
            (the default) is handed to ``Deadline.after`` as ``None``, i.e. no
            time limit is requested.

        Returns
        -------
        bool
            ``True`` if a read lock was acquired, ``False`` on timeout.
        """
        deadline = Deadline.after(timeout if timeout >= 0 else None)
        with self._condition:
            if not self._condition.wait_for(
                lambda: not self._write_pending, timeout=deadline.remaining
            ):
                return False
            self._n_reads += 1
            return True

    def release_read(self) -> None:
        """Release one read lock.

        Waiting threads are only notified once the last read lock is released.

        Raises
        ------
        RuntimeError
            If no read lock is currently held.
        """
        with self._condition:
            if self._n_reads == 0:
                raise RuntimeError("Tried releasing unlocked read lock")
            self._n_reads -= 1
            if self._n_reads == 0:
                self._condition.notify_all()

    @contextmanager
    def write(self) -> Generator[None]:
        """Context manager holding the write lock; waits without a timeout."""
        self.acquire_write()
        try:
            yield
        finally:
            self.release_write()

    @contextmanager
    def read(self) -> Generator[None]:
        """Context manager holding a read lock; waits without a timeout."""
        self.acquire_read()
        try:
            yield
        finally:
            self.release_read()


class DiskShardsBuffer(ShardsBuffer):
    """Accept, buffer, and write many small objects to many files

    This takes in lots of small objects, writes them to a local directory, and
    then reads them back when all writes are complete.  It buffers these
    objects in memory so that it can optimize disk access for larger writes.

    **State**

    -   shards: dict[str, list[bytes]]

        This is our in-memory buffer of data waiting to be written to files.

    -   sizes: dict[str, int]

        The size of each list of shards.  We find the largest and write data from that buffer

    Parameters
    ----------
    directory : str or pathlib.Path
        Where to write and read data.  Ideally points to fast disk.
    read : callable
        Called by :meth:`read` with the resolved path of a shard file; must
        return a ``(data, size)`` tuple, where ``size`` is the number of bytes
        that gets recorded in the read metrics.
    memory_limiter : ResourceLimiter
        Limiter for in-memory buffering (at most this much data)
        before writes to disk occur. If the incoming data that has yet
        to be processed exceeds this limit, then the buffer will block
        until below the threshold. See :meth:`.write` for the
        implementation of this scheme.
    """

    def __init__(
        self,
        directory: str | pathlib.Path,
        read: Callable[[pathlib.Path], tuple[Any, int]],
        memory_limiter: ResourceLimiter,
    ):
        super().__init__(
            memory_limiter=memory_limiter,
            # Disk is not able to run concurrently atm
            concurrency_limit=1,
        )
        self.directory = pathlib.Path(directory)
        self.directory.mkdir(exist_ok=True)
        self._closed = False
        self._read = read
        # Writers of individual files and readers share the directory;
        # close() takes it exclusively before deleting it.
        self._directory_lock = ReadWriteLock()

    @log_errors
    async def _process(self, id: str, shards: list[Any]) -> None:
        """Write one buffer to file

        This function was built to offload the disk IO, but since then we've
        decided to keep this within the event loop (disk bandwidth should be
        prioritized, and writes are typically small enough to not be a big
        deal).

        Most of the logic here is about possibly going back to a separate
        thread, or about diagnostics.  If things don't change much in the
        future then we should consider simplifying this considerably and
        dropping the write into communicate above.

        Raises
        ------
        RuntimeError
            If the buffer has already been closed.
        P2POutOfDiskError
            If the write fails with ``ENOSPC``.
        """
        frames: Iterable[bytes | bytearray | memoryview]
        if isinstance(shards[0], bytes):
            # Manually serialized dataframes
            frames = shards
            serialize_meter_ctx: Any = empty_context
        else:
            # Unserialized numpy arrays
            # Note: no calls to pickle_bytelist will happen until we actually start
            # writing to disk below.
            frames = concat(pickle_bytelist(shard) for shard in shards)
            serialize_meter_ctx = context_meter.meter("serialize", func=thread_time)

        nbytes_written = 0

        def counted_frames() -> Iterable[bytes | bytearray | memoryview]:
            # ``frames`` may be a one-shot iterator (see the ``concat`` branch
            # above), so it cannot be iterated a second time after the write to
            # compute the metric.  Tally the sizes while the frames stream to
            # the file instead; this also keeps serialization lazy.
            nonlocal nbytes_written
            for frame in frames:
                nbytes_written += nbytes(frame)
                yield frame

        with (
            self._directory_lock.read(),
            context_meter.meter("disk-write"),
            serialize_meter_ctx,
        ):
            # Consider boosting total_size a bit here to account for duplication
            # We only need shared (i.e., read) access to the directory to write
            # to a file inside of it.
            if self._closed:
                raise RuntimeError("Already closed")

            try:
                self._write_frames(counted_frames(), id)
            except OSError as e:
                if e.errno == errno.ENOSPC:
                    raise P2POutOfDiskError from e
                raise
        context_meter.digest_metric("disk-write", 1, "count")
        context_meter.digest_metric("disk-write", nbytes_written, "bytes")

    def _write_frames(
        self, frames: Iterable[bytes | bytearray | memoryview], id: str
    ) -> None:
        """Append ``frames`` to the file named ``id`` inside the directory"""
        with open(self.directory / str(id), mode="ab") as f:
            f.writelines(frames)

    def read(self, id: str) -> Any:
        """Read a complete file back into memory

        Raises
        ------
        RuntimeError
            If inputs are not done yet, or if the buffer has been closed.
        DataUnavailable
            If there is no file for ``id``, or the reader returned falsy data.
        """
        self.raise_on_exception()
        if not self._inputs_done:
            raise RuntimeError("Tried to read from file before done.")

        try:
            with self._directory_lock.read():
                if self._closed:
                    raise RuntimeError("Already closed")
                fname = (self.directory / str(id)).resolve()
                # Note: don't add `with context_meter.meter("p2p-disk-read"):` to
                # measure seconds here, as it would shadow "p2p-get-output-cpu" and
                # "p2p-get-output-noncpu". Also, for rechunk it would not measure
                # the whole disk access, as _read returns memory-mapped buffers.
                data, size = self._read(fname)
                context_meter.digest_metric("p2p-disk-read", 1, "count")
                context_meter.digest_metric("p2p-disk-read", size, "bytes")
        except FileNotFoundError as e:
            raise DataUnavailable(id) from e

        if data:
            self.bytes_read += size
            return data
        else:
            raise DataUnavailable(id)

    async def close(self) -> None:
        """Close the buffer, then delete the directory and everything in it"""
        await super().close()
        # Exclusive access: wait for in-flight file reads/writes to finish
        with self._directory_lock.write():
            self._closed = True
            with contextlib.suppress(FileNotFoundError):
                shutil.rmtree(self.directory)