File: dask.py

package info (click to toggle)
python-eliot 1.16.0-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 964 kB
  • sloc: python: 8,641; makefile: 151
file content (172 lines) | stat: -rw-r--r-- 5,440 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
"""Support for Eliot tracing with Dask computations."""

from pyrsistent import PClass, field

from dask import compute, optimize, persist

try:
    from dask.distributed import Future
    from dask.highlevelgraph import HighLevelGraph
except ImportError:
    # dask.distributed is an optional dependency.  Catch only ImportError
    # (a bare except would also hide e.g. SyntaxError or KeyboardInterrupt)
    # and define placeholder classes so isinstance() checks elsewhere in
    # this module still work — they will simply never match.

    class Future(object):
        """Placeholder for dask.distributed.Future when it is unavailable."""

    class HighLevelGraph(object):
        """Placeholder for dask.highlevelgraph.HighLevelGraph when unavailable."""



from dask.core import toposort, get_dependencies, ishashable
from . import start_action, current_action, Action


class _RunWithEliotContext(PClass):
    """
    Run a callable within an Eliot context.

    Instances of this class replace the task callables in a Dask graph.
    When Dask invokes the wrapper, it resumes the serialized Eliot task,
    logs a C{"dask:task"} message recording the graph key and its
    dependencies, and then calls the wrapped function with the original
    arguments.

    @ivar task_id: The serialized Eliot task ID.
    @ivar func: The function that Dask wants to run.
    @ivar key: The key in the Dask graph.
    @ivar dependencies: The keys in the Dask graph this depends on.
    """

    task_id = field(type=str)
    func = field()  # callable
    key = field(type=str)
    dependencies = field()

    # Pretend to be underlying callable for purposes of equality; necessary for
    # optimizer to be happy:

    def __eq__(self, other):
        # Delegate equality to the wrapped callable, so the wrapper
        # compares equal to the function it replaced.
        return self.func == other

    def __ne__(self, other):
        return self.func != other

    def __hash__(self):
        # Hash like the wrapped callable, consistent with __eq__ above.
        return hash(self.func)

    def __call__(self, *args, **kwargs):
        # Resume the Eliot task this wrapper was created under, record
        # which graph entry is running and what it depends on, then run
        # the real function.
        with Action.continue_task(task_id=self.task_id) as action:
            action.log(
                message_type="dask:task", key=self.key, dependencies=self.dependencies
            )
            return self.func(*args, **kwargs)


def compute_with_trace(*args):
    """Do Dask compute(), but with added Eliot tracing.

    Dask is a graph of tasks, but Eliot logs trees.  So we need to emulate a
    graph using a tree.  We do this by making Eliot action for each task, but
    having it list the tasks it depends on.

    We use the following algorithm:

        1. Create a top-level action.

        2. For each entry in the dask graph, create a child with
           serialize_task_id.  Do this in likely order of execution, so that
           if B depends on A the task level of B is higher than the task Ievel
           of A.

        3. Replace each function with a wrapper that uses the corresponding
           task ID (with Action.continue_task), and while it's at it also
           records which other things this function depends on.

    Known issues:

        1. Retries will confuse Eliot.  Probably need different
           distributed-tree mechanism within Eliot to solve that.
    """
    # Wrap the whole computation in a single top-level Eliot action, then
    # inject per-task logging into the graph *after* Dask's own optimization
    # passes have run, which keeps the logging output smaller.
    with start_action(action_type="dask:compute"):
        instrumented = optimize(*args, optimizations=[_add_logging])
        # Optimization already happened above, so skip it here:
        return compute(*instrumented, optimize_graph=False)


def persist_with_trace(*args):
    """Do Dask persist(), but with added Eliot tracing.

    Known issues:

        1. Retries will confuse Eliot.  Probably need different
           distributed-tree mechanism within Eliot to solve that.
    """
    # Wrap the whole persist in a single top-level Eliot action, then
    # inject per-task logging into the graph *after* Dask's own optimization
    # passes have run, which keeps the logging output smaller.
    with start_action(action_type="dask:persist"):
        instrumented = optimize(*args, optimizations=[_add_logging])
        # Optimization already happened above, so skip it here:
        return persist(*instrumented, optimize_graph=False)


def _add_logging(dsk, ignore=None):
    """
    Add logging to a Dask graph.

    Every task callable in the graph is replaced with a
    L{_RunWithEliotContext} wrapper that resumes a serialized child of the
    current Eliot action, so the resulting log tree mirrors the graph.

    @param dsk: The Dask graph (a plain dict, or a C{HighLevelGraph} which
        is first flattened to a dict).
    @param ignore: Unused; present so this function matches the signature
        Dask expects of an optimization callback.

    @return: New Dask graph (a plain dict with the same keys).
    """
    if isinstance(dsk, HighLevelGraph):
        dsk = dsk.to_dict()

    ctx = current_action()
    result = {}

    # Use topological sort to ensure Eliot actions are in logical order of
    # execution in Dask:
    keys = toposort(dsk)
    # Membership testing against the list would make the alias-detection
    # loop below O(n**2); build a set once for O(1) lookups.  Graph keys
    # are dict keys, hence hashable.
    key_set = set(keys)

    # Give each key a string name. Some keys are just aliases to other
    # keys, so make sure we have underlying key available. Later on might
    # want to shorten them as well.
    def simplify(k):
        if isinstance(k, str):
            return k
        return "-".join(str(o) for o in k)

    key_names = {}
    for key in keys:
        value = dsk[key]
        if not callable(value) and ishashable(value) and value in key_set:
            # It's an alias for another key; topological order guarantees
            # the aliased key was already named.
            key_names[key] = key_names[value]
        else:
            key_names[key] = simplify(key)

    # Values in the graph can be either:
    #
    # 1. A list of other values.
    # 2. A tuple, where first value might be a callable, aka a task.
    # 3. A literal of some sort.
    def maybe_wrap(key, value):
        if isinstance(value, list):
            return [maybe_wrap(key, v) for v in value]
        elif isinstance(value, tuple):
            func = value[0]
            args = value[1:]
            if not callable(func):
                # Not a callable, so nothing to wrap.
                return value
            wrapped_func = _RunWithEliotContext(
                task_id=str(ctx.serialize_task_id(), "utf-8"),
                func=func,
                key=key_names[key],
                dependencies=[key_names[k] for k in get_dependencies(dsk, key)],
            )
            return (wrapped_func,) + args
        else:
            return value

    # Replace function with wrapper that logs appropriate Action; iterate in
    # topological order so action task levels are in reasonable order.
    for key in keys:
        result[key] = maybe_wrap(key, dsk[key])

    assert set(result.keys()) == set(dsk.keys())
    return result


__all__ = ["compute_with_trace", "persist_with_trace"]