File: recreate_tasks.py

package info (click to toggle)
dask.distributed 2022.12.1+ds.1-3
  • links: PTS, VCS
  • area: main
  • in suites: bookworm
  • size: 10,164 kB
  • sloc: python: 81,938; javascript: 1,549; makefile: 228; sh: 100
file content (204 lines) | stat: -rw-r--r-- 7,620 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
from __future__ import annotations

import logging

from dask.utils import stringify

from distributed.client import futures_of, wait
from distributed.utils import sync
from distributed.utils_comm import pack_data
from distributed.worker import _deserialize

logger = logging.getLogger(__name__)


class ReplayTaskScheduler:
    """A plugin for the scheduler to recreate tasks locally

    This adds the following routes to the scheduler

    *  get_runspec
    *  get_error_cause
    """

    def __init__(self, scheduler):
        self.scheduler = scheduler
        # Register the handlers so clients can call these routes on the scheduler.
        self.scheduler.handlers["get_runspec"] = self.get_runspec
        self.scheduler.handlers["get_error_cause"] = self.get_error_cause

    def _process_key(self, key):
        """Normalise a key received from a client before looking it up.

        Lists are converted back to tuples (msgpack may deliver tuple keys as
        lists), then the key is passed through ``stringify``.
        """
        if isinstance(key, list):
            key = tuple(key)  # ensure not a list from msgpack
        return stringify(key)

    def get_error_cause(self, *args, keys=(), **kwargs):
        """Return the key of the task blamed for an error.

        Parameters
        ----------
        keys : iterable
            Candidate task keys, checked in order.

        Returns
        -------
        The ``exception_blame`` key of the first task in ``keys`` that is known
        to the scheduler and has an ``exception_blame`` set, or ``None`` if no
        such task exists.
        """
        for key in keys:
            ts = self.scheduler.tasks.get(self._process_key(key))
            if ts is not None and ts.exception_blame is not None:
                return ts.exception_blame.key
        return None

    def get_runspec(self, *args, key=None, **kwargs):
        """Return the run spec and dependency keys of the task ``key``.

        Returns
        -------
        dict
            ``{"task": <the task's run_spec>, "deps": [<dependency keys>]}``

        Raises
        ------
        KeyError
            If the scheduler has no task for ``key``. This replaces an opaque
            ``AttributeError`` on ``None``.
        """
        key = self._process_key(key)
        ts = self.scheduler.tasks.get(key)
        if ts is None:
            raise KeyError(f"Task {key!r} is not known to the scheduler")
        return {"task": ts.run_spec, "deps": [dts.key for dts in ts.dependencies]}


class ReplayTaskClient:
    """
    A plugin for the client allowing replay of remote tasks locally

    Adds the following methods to the given client:

    - ``recreate_error_locally``: main user method for replaying failed tasks
    - ``recreate_task_locally``: main user method for replaying any task
    """

    def __init__(self, client):
        self.client = client
        self.client.extensions["replay-tasks"] = self
        # monkey patch: expose this plugin's methods as attributes of the client
        self.client._get_raw_components_from_future = (
            self._get_raw_components_from_future
        )
        self.client._prepare_raw_components = self._prepare_raw_components
        self.client._get_components_from_future = self._get_components_from_future
        self.client._get_errored_future = self._get_errored_future
        self.client.recreate_task_locally = self.recreate_task_locally
        self.client.recreate_error_locally = self.recreate_error_locally

    @property
    def scheduler(self):
        """The client's handle to the scheduler (``self.client.scheduler``)."""
        return self.client.scheduler

    async def _get_raw_components_from_future(self, future):
        """
        For a given future (or task key given as a ``str``) return the func,
        args, kwargs and dependency keys that would be executed remotely.

        Dependencies are returned as keys and are not resolved here; see
        ``_prepare_raw_components``.
        """
        if isinstance(future, str):
            key = future
        else:
            await wait(future)
            key = future.key
        spec = await self.scheduler.get_runspec(key=key)
        deps, task = spec["deps"], spec["task"]
        # A dict run spec is unpacked into keyword arguments for _deserialize;
        # anything else is passed as the single ``task`` argument.
        if isinstance(task, dict):
            function, args, kwargs = _deserialize(**task)
        else:
            function, args, kwargs = _deserialize(task=task)
        return (function, args, kwargs, deps)

    async def _prepare_raw_components(self, raw_components):
        """
        Take raw components and resolve future dependencies.

        The dependency keys are turned into futures, their results are
        gathered, and ``pack_data`` substitutes those results into args and
        kwargs.
        """
        function, args, kwargs, deps = raw_components
        futures = self.client._graph_to_futures({}, deps)
        data = await self.client._gather(futures)
        args = pack_data(args, data)
        kwargs = pack_data(kwargs, data)
        return (function, args, kwargs)

    async def _get_components_from_future(self, future):
        """
        For a given future return the func, args and kwargs that would be
        executed remotely. Any args/kwargs that are themselves futures will
        be resolved to the return value of those futures.
        """
        raw_components = await self._get_raw_components_from_future(future)
        return await self._prepare_raw_components(raw_components)

    def recreate_task_locally(self, future):
        """
        For any calculation, whether it succeeded or failed, perform the task
        locally for debugging.

        This operation should be performed after a future (result of ``gather``,
        ``compute``, etc) comes back with a status other than "pending". Cases
        where you might want to debug a successfully completed future could
        include a calculation that returns an unexpected results. A common
        debugging process might include running the task locally in debug mode,
        with `pdb.runcall`.

        Examples
        --------
        >>> import pdb                                    # doctest: +SKIP
        >>> future = c.submit(div, 1, 1)                  # doctest: +SKIP
        >>> future.status                                 # doctest: +SKIP
        'finished'
        >>> pdb.runcall(c.recreate_task_locally, future)  # doctest: +SKIP

        Parameters
        ----------
        future : future
            The same thing as was given to ``gather``.

        Returns
        -------
        Any; will return the result of the task future.
        """
        func, args, kwargs = sync(
            self.client.loop, self._get_components_from_future, future
        )
        return func(*args, **kwargs)

    async def _get_errored_future(self, future):
        """
        For a given future collection, return the key of the task that the
        scheduler blames for the error.

        Raises ``ValueError`` if no future in the collection has status
        "error". May return ``None`` if the scheduler cannot identify a
        blamed task.
        """
        await wait(future)
        error_keys = [f.key for f in futures_of(future) if f.status == "error"]
        if not error_keys:
            raise ValueError("No errored futures passed")
        cause_key = await self.scheduler.get_error_cause(keys=error_keys)
        return cause_key

    def recreate_error_locally(self, future):
        """
        For a failed calculation, perform the blamed task locally for debugging.

        This operation should be performed after a future (result of ``gather``,
        ``compute``, etc) comes back with a status of "error", if the stack-
        trace is not informative enough to diagnose the problem. The specific
        task (part of the graph pointing to the future) responsible for the
        error will be fetched from the scheduler, together with the values of
        its inputs. The function will then be executed, so that ``pdb`` can
        be used for debugging.

        Examples
        --------
        >>> future = c.submit(div, 1, 0)         # doctest: +SKIP
        >>> future.status                        # doctest: +SKIP
        'error'
        >>> c.recreate_error_locally(future)     # doctest: +SKIP
        ZeroDivisionError: division by zero

        If you're in IPython you might take this opportunity to use pdb

        >>> %pdb                                 # doctest: +SKIP
        Automatic pdb calling has been turned ON

        >>> c.recreate_error_locally(future)     # doctest: +SKIP
        ZeroDivisionError: division by zero
              1 def div(x, y):
        ----> 2     return x / y
        ipdb>

        Parameters
        ----------
        future : future or collection that failed
            The same thing as was given to ``gather``, but came back with
            an exception/stack-trace. Can also be a (persisted) dask collection
            containing any errored futures.

        Returns
        -------
        Nothing; the function runs and should raise an exception, allowing
        the debugger to run.

        Raises
        ------
        ValueError
            If no errored futures were passed, or if the scheduler could not
            identify the task responsible for the error.
        """
        errored_future_key = sync(self.client.loop, self._get_errored_future, future)
        if errored_future_key is None:
            # Without this check, None would be forwarded to
            # recreate_task_locally and fail there with an unrelated error.
            raise ValueError(
                "The scheduler could not identify which task caused the error"
            )
        return self.recreate_task_locally(errored_future_key)