File: memory_sampler.py

package info (click to toggle)
dask.distributed 2022.12.1%2Bds.1-3
  • links: PTS, VCS
  • area: main
  • in suites: bookworm
  • size: 10,164 kB
  • sloc: python: 81,938; javascript: 1,549; makefile: 228; sh: 100
file content (223 lines) | stat: -rw-r--r-- 7,000 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
from __future__ import annotations

import uuid
from collections.abc import AsyncIterator, Iterator
from contextlib import asynccontextmanager, contextmanager
from datetime import datetime
from typing import TYPE_CHECKING, Any, cast

from distributed.compatibility import PeriodicCallback

if TYPE_CHECKING:
    # Optional runtime dependencies
    import pandas as pd

    # Circular dependencies
    from distributed.client import Client
    from distributed.scheduler import Scheduler


class MemorySampler:
    """Sample cluster-wide memory usage every <interval> seconds.

    Each sampling session is stored in :attr:`samples` under its label as a list of
    ``(timestamp, nbytes)`` tuples, as returned by the scheduler when the session
    is stopped.

    **Usage**

    .. code-block:: python

       client = Client()
       ms = MemorySampler()

       with ms.sample("run 1"):
           <run first workflow>
       with ms.sample("run 2"):
           <run second workflow>
       ...
       ms.plot()

    or with an asynchronous client:

    .. code-block:: python

       client = await Client(asynchronous=True)
       ms = MemorySampler()

       async with ms.sample("run 1"):
           <run first workflow>
       async with ms.sample("run 2"):
           <run second workflow>
       ...
       ms.plot()
    """

    # {label: [(timestamp in seconds, memory in bytes), ...]}
    samples: dict[str, list[tuple[float, int]]]

    def __init__(self):
        self.samples = {}

    def sample(
        self,
        label: str | None = None,
        *,
        client: Client | None = None,
        measure: str = "process",
        interval: float = 0.5,
    ) -> Any:
        """Context manager that records memory usage in the cluster.
        This is synchronous if the client is synchronous and
        asynchronous if the client is asynchronous.

        The samples are recorded in ``self.samples[<label>]``.

        Parameters
        ==========
        label: str, optional
            Tag to record the samples under in the self.samples dict.
            Default: use the key generated by the scheduler for this session
        client: Client, optional
            client used to connect to the scheduler.
            Default: use the global client
        measure: str, optional
            One of the measures from :class:`distributed.scheduler.MemoryState`.
            Default: sample process memory
        interval: float, optional
            sampling interval, in seconds.
            Default: 0.5

        Returns
        =======
        A context manager if ``client.asynchronous`` is false, an asynchronous
        context manager otherwise.
        """
        # Test for None explicitly: only a missing client should trigger the
        # fallback to the global one.
        if client is None:
            # Imported here because distributed.client is a circular dependency
            from distributed.client import get_client

            client = get_client()

        if client.asynchronous:
            return self._sample_async(label, client, measure, interval)
        return self._sample_sync(label, client, measure, interval)

    @contextmanager
    def _sample_sync(
        self, label: str | None, client: Client, measure: str, interval: float
    ) -> Iterator[None]:
        """Synchronous implementation of :meth:`sample`.

        Starts sampling on the scheduler on entry; on exit (also when the body
        raises) stops it and stores the returned samples.
        """
        key = client.sync(
            client.scheduler.memory_sampler_start,
            client=client.id,
            measure=measure,
            interval=interval,
        )
        try:
            yield
        finally:
            samples = client.sync(client.scheduler.memory_sampler_stop, key=key)
            # A missing (or empty) label falls back to the scheduler-side key
            self.samples[label or key] = samples

    @asynccontextmanager
    async def _sample_async(
        self, label: str | None, client: Client, measure: str, interval: float
    ) -> AsyncIterator[None]:
        """Asynchronous implementation of :meth:`sample`.

        Starts sampling on the scheduler on entry; on exit (also when the body
        raises) stops it and stores the returned samples.
        """
        key = await client.scheduler.memory_sampler_start(
            client=client.id, measure=measure, interval=interval
        )
        try:
            yield
        finally:
            samples = await client.scheduler.memory_sampler_stop(key=key)
            # A missing (or empty) label falls back to the scheduler-side key
            self.samples[label or key] = samples

    def to_pandas(self, *, align: bool = False) -> pd.DataFrame:
        """Return the data series as a pandas.Dataframe.

        The result has one column per label in ``self.samples`` and is indexed by
        sample time.

        Parameters
        ==========
        align : bool, optional
            If True, change the absolute timestamps into time deltas from the first
            sample of each series, so that different series can be visualized side by
            side. If False (the default), use absolute timestamps.
        """
        import pandas as pd

        ss = {}
        for label, s_list in self.samples.items():
            assert s_list  # There's always at least one sample
            # Column 0 holds the timestamps, column 1 the memory readings
            s = pd.DataFrame(s_list).set_index(0)[1]
            s.index = pd.to_datetime(s.index, unit="s")
            s.name = label
            if align:
                # convert datetime to timedelta from the first sample
                s.index -= cast(pd.Timestamp, s.index[0])
            ss[label] = s

        df = pd.DataFrame(ss)

        if len(ss) > 1:
            # Forward-fill NaNs in the middle of a series created either by overlapping
            # sampling time range or by align=True. Do not ffill series beyond their
            # last sample.
            df = df.ffill().where(~pd.isna(df.bfill()))

        return df

    def plot(self, *, align: bool = False, **kwargs: Any) -> Any:
        """Plot data series collected so far

        Parameters
        ==========
        align : bool (optional)
            See :meth:`~distributed.diagnostics.MemorySampler.to_pandas`
        kwargs
            Passed verbatim to :meth:`pandas.DataFrame.plot`. ``xlabel`` and
            ``ylabel`` default to ``"time"`` and ``"Cluster memory (GiB)"`` and may
            be overridden here.

        Returns
        =======
        Output of :meth:`pandas.DataFrame.plot`
        """
        # Samples are in bytes; plot them in GiB
        df = self.to_pandas(align=align) / 2**30
        # Merge the default axis labels with the caller's kwargs instead of passing
        # both: a caller-supplied xlabel/ylabel would otherwise raise TypeError for
        # a duplicated keyword argument.
        plot_kwargs = {"xlabel": "time", "ylabel": "Cluster memory (GiB)", **kwargs}
        return df.plot(**plot_kwargs)


class MemorySamplerExtension:
    """Scheduler extension - server side of MemorySampler

    Registers the ``memory_sampler_start`` and ``memory_sampler_stop`` handlers on
    the scheduler and holds the samples of every sampling session in progress.
    """

    scheduler: Scheduler
    # Sessions in progress: {session key: [(timestamp in seconds, bytes), ...]}
    samples: dict[str, list[tuple[float, int]]]

    def __init__(self, scheduler: Scheduler):
        """Attach the extension and its two handlers to ``scheduler``"""
        self.scheduler = scheduler
        self.scheduler.extensions["memory_sampler"] = self
        self.scheduler.handlers["memory_sampler_start"] = self.start
        self.scheduler.handlers["memory_sampler_stop"] = self.stop
        self.samples = {}

    def start(self, client: str, measure: str, interval: float) -> str:
        """Start periodically sampling memory

        Parameters
        ==========
        client: str
            ID of the client that requested the session. Sampling stops by itself
            as soon as this ID is no longer in ``scheduler.clients``.
        measure: str
            Name of a public, integer attribute of ``scheduler.memory``
        interval: float
            sampling interval, in seconds

        Returns
        =======
        Unique key of the new session, to be passed to :meth:`stop`

        Raises
        ======
        ValueError
            If ``measure`` starts with an underscore or is not an integer
        AttributeError
            If ``scheduler.memory`` has no attribute named ``measure``
        """
        # ``measure`` is used in a getattr() below, so it must be validated with
        # explicit checks: assert statements are stripped when running with -O.
        if measure.startswith("_"):
            raise ValueError(f"Invalid memory measure: {measure!r}")
        if not isinstance(getattr(self.scheduler.memory, measure), int):
            raise ValueError(f"Memory measure {measure!r} is not an integer")

        key = str(uuid.uuid4())
        self.samples[key] = []

        def sample():
            if client in self.scheduler.clients:
                ts = datetime.now().timestamp()
                nbytes = getattr(self.scheduler.memory, measure)
                self.samples[key].append((ts, nbytes))
            else:
                # The client is gone: stop sampling and drop the collected samples
                self.stop(key)

        # PeriodicCallback is given the interval in milliseconds
        pc = PeriodicCallback(sample, interval * 1000)
        self.scheduler.periodic_callbacks["MemorySampler-" + key] = pc
        pc.start()

        # Immediately collect the first sample; this also ensures there's always at
        # least one sample
        sample()

        return key

    def stop(self, key: str) -> list[tuple[float, int]]:
        """Stop sampling and return the samples

        Parameters
        ==========
        key: str
            Session key returned by :meth:`start`

        Raises
        ======
        KeyError
            If there is no session in progress for ``key``
        """
        pc = self.scheduler.periodic_callbacks.pop("MemorySampler-" + key)
        pc.stop()
        return self.samples.pop(key)