File: spans.py

package info (click to toggle)
dask.distributed 2024.12.1%2Bds-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 12,588 kB
  • sloc: python: 96,954; javascript: 1,549; sh: 390; makefile: 220
file content (685 lines) | stat: -rw-r--r-- 23,767 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
from __future__ import annotations

import copy
import uuid
import weakref
from collections import defaultdict
from collections.abc import Hashable, Iterable, Iterator, Mapping
from contextlib import contextmanager
from itertools import islice
from typing import TYPE_CHECKING, Any, TypedDict

import dask.config
from dask.typing import Key

from distributed.collections import sum_mappings
from distributed.itertools import ffill
from distributed.metrics import time

if TYPE_CHECKING:
    # Needed to avoid Sphinx WARNING: more than one target found for cross-reference
    # 'TaskState' and 'WorkerState'"
    # https://github.com/agronholm/sphinx-autodoc-typehints#dealing-with-circular-imports
    from distributed import Scheduler, Worker
    from distributed import scheduler as scheduler_module
    from distributed.client import SourceCode
    from distributed.scheduler import TaskGroup, TaskStateState


# Metric contexts (first element of a metric key) that SpansWorkerExtension.collect_digests
# keeps for the heartbeat. SpansSchedulerExtension.heartbeat reads the second element
# of those keys as a span id.
CONTEXTS_WITH_SPAN_ID = ("execute", "p2p")


class SpanMetadata(TypedDict):
    """Metadata that can be attached to a :class:`Span` through
    :meth:`Span.add_metadata`.
    """

    # List of dicts; presumably one per dask collection. The schema of the dicts is
    # not defined in this module -- TODO confirm against the client code.
    collections: list[dict]


@contextmanager
def span(*tags: str) -> Iterator[str]:
    """Group the tasks defined inside this context into a *span*.

    Spans nest: entering ``span`` while another span is active creates a sub-span.
    Closing a span and opening a new one with the same tag produces two distinct
    spans, because a new ID is generated every time.

    Every cluster defines a global "default" span when no span has been defined by the
    client; the default span is automatically closed and reopened when all tasks
    associated to it have been completed; in other words the cluster is idle save for
    tasks that are explicitly annotated by a span. Note that, in some edge cases, you
    may end up with overlapping default spans, e.g. if a worker crashes and all unique
    tasks that were in memory on it need to be recomputed.

    The value yielded by the context manager is the ID of the innermost span that was
    just opened. You may capture it on the client to match it with the
    :class:`~distributed.spans.Span` objects on the scheduler:

    >>> client = Client()
    >>> with span("my workflow") as span_id:
    ...     client.submit(lambda: "Hello world!").result()
    >>> client.cluster.scheduler.extensions["spans"].spans[span_id]
    Span<name=('my workflow',), id=5dc9b908-116b-49a5-b0d7-5a681f49a111>

    Parameters
    ----------
    *tags : str
        One or more tags. Each tag gets its own freshly generated ID.

    Raises
    ------
    ValueError
        If no tags are passed.

    Notes
    -----
    You may retrieve the current span with ``dask.get_annotations().get("span")``.
    You can do so in the client code as well as from inside a task.
    """
    if len(tags) == 0:
        raise ValueError("Must specify at least one span tag")

    enclosing = dask.get_annotations().get("span")
    if enclosing:
        outer_name = enclosing["name"]
        # The whole chain of IDs is carried along, not just the parent's: a
        # grandparent that has no tasks of its own could not be identified
        # unambiguously otherwise.
        outer_ids = enclosing["ids"]
    else:
        outer_name = ()
        outer_ids = ()

    fresh_ids = tuple(str(uuid.uuid4()) for _ in tags)
    nested = {"name": outer_name + tags, "ids": outer_ids + fresh_ids}
    with dask.annotate(span=nested):
        yield fresh_ids[-1]


class Span:
    """Node of the span tree.

    A span references the task groups directly attached to it (``groups``) and its
    child spans (``children``). Most properties aggregate over the whole branch, i.e.
    this span plus all of its descendants.
    """

    #: (<tag>, <tag>, ...)
    #: Matches ``TaskState.annotations["span"]["name"]``, both on the scheduler and the
    #: worker.
    name: tuple[str, ...]

    #: Unique ID, generated by :func:`~distributed.span` and
    #: taken from ``TaskState.annotations["span"]["ids"][-1]``.
    #: Matches ``distributed.scheduler.TaskState.group.span_id``
    #: and ``distributed.worker_state_machine.TaskState.span_id``.
    id: str

    #: Weak reference to the parent span; None for a root span
    _parent: weakref.ref[Span] | None

    #: Direct children of this span, sorted by creation time
    children: list[Span]

    #: Task groups *directly* belonging to this span.
    #:
    #: See Also
    #: --------
    #: traverse_groups
    #:
    #: Notes
    #: -----
    #: TaskGroups are forgotten by the Scheduler when the last task is forgotten, but
    #: remain referenced here indefinitely. If a user calls compute() twice on the same
    #: collection, you'll have more than one group with the same tg.name in this set!
    #: For the same reason, while the same TaskGroup object is guaranteed to be attached
    #: to exactly one Span, you may have different TaskGroups with the same key attached
    #: to different Spans.
    groups: set[TaskGroup]

    #: Time when the span first appeared on the scheduler.
    #: The same property on parent spans is always less than or equal to this.
    #:
    #: See Also
    #: --------
    #: start
    #: stop
    enqueued: float

    #: Source code snippets, if it was sent by the client.
    #: We're using a dict without values as an insertion-sorted set.
    _code: dict[tuple[SourceCode, ...], None]
    _metadata: SpanMetadata | None

    #: Metadata objects already merged into ``_metadata`` by :meth:`add_metadata`,
    #: keyed by their ``id()``. The objects themselves are stored as values so that
    #: they stay alive and their ``id()`` cannot be recycled by an unrelated object.
    _metadata_seen: dict[int, SpanMetadata]

    _cumulative_worker_metrics: defaultdict[tuple[Hashable, ...], float]

    #: reference to SchedulerState.total_nthreads_history
    _total_nthreads_history: list[tuple[float, int]]
    #: Length of total_nthreads_history when this span was enqueued
    _total_nthreads_offset: int

    # Support for weakrefs to a class with __slots__
    __weakref__: Any

    # Every attribute annotated above this line becomes a slot
    __slots__ = tuple(__annotations__)

    def __init__(
        self,
        name: tuple[str, ...],
        id_: str,
        parent: Span | None,
        total_nthreads_history: list[tuple[float, int]],
    ):
        """
        Parameters
        ----------
        name
            Full tuple of tags, from the root span down to this one
        id_
            Unique ID of this span
        parent
            Parent span, or None for a root span. Only a weak reference is kept.
        total_nthreads_history
            Non-empty list of (timestamp, number of threads); stored by reference,
            not copied
        """
        self.name = name
        self.id = id_
        self._parent = weakref.ref(parent) if parent is not None else None
        self.enqueued = time()
        self.children = []
        self.groups = set()
        self._code = {}
        self._metadata = None
        self._metadata_seen = {}

        # Don't cast int metrics to float
        self._cumulative_worker_metrics = defaultdict(int)

        assert len(total_nthreads_history) > 0
        self._total_nthreads_history = total_nthreads_history
        # Index of the latest history entry at the time this span is created
        self._total_nthreads_offset = len(total_nthreads_history) - 1

    def __repr__(self) -> str:
        return f"Span<name={self.name}, id={self.id}>"

    @property
    def parent(self) -> Span | None:
        """Parent span, or None if this is a root span"""
        if self._parent:
            out = self._parent()
            # The weak reference must still be alive
            assert out
            return out
        return None

    def add_metadata(self, metadata: SpanMetadata) -> None:
        """Add metadata to the span, e.g. code snippets.

        Passing the very same object more than once to the same span is a no-op after
        the first call. Deduplication is per span: the same object can be added once
        to each of several spans.
        """
        id_ = id(metadata)
        if id_ in self._metadata_seen:
            return
        # Keep the object alive: otherwise its id() could be reused by a different
        # metadata object, which would then be skipped by mistake.
        self._metadata_seen[id_] = metadata
        if self._metadata is None:
            # Copy, so that later calls don't extend the caller's own list
            self._metadata = copy.deepcopy(metadata)
        else:
            self._metadata["collections"].extend(metadata["collections"])

    @property
    def annotation(self) -> dict[str, tuple[str, ...]] | None:
        """Rebuild the dask graph annotation which contains the full id history

        Returns None for a span named ``("default",)``.

        Note that this may not match the original annotation in case of TaskGroup
        collision.
        """
        if self.name == ("default",):
            return None
        ids = []
        node: Span | None = self
        # Walk up to the root, then reverse to get root-first order
        while node:
            ids.append(node.id)
            node = node.parent
        return {"name": self.name, "ids": tuple(reversed(ids))}

    def traverse_spans(self) -> Iterator[Span]:
        """Top-down recursion of all spans belonging to this branch of span tree,
        including self
        """
        yield self
        for child in self.children:
            yield from child.traverse_spans()

    def traverse_groups(self) -> Iterator[TaskGroup]:
        """All TaskGroups belonging to this branch of span tree"""
        for span in self.traverse_spans():
            yield from span.groups

    @property
    def start(self) -> float:
        """Earliest time when a task belonging to this span tree started computing;
        0 if no task has *finished* computing yet.

        Notes
        -----
        This is not updated until at least one task has *finished* computing.
        It could move backwards as tasks complete.

        See Also
        --------
        enqueued
        stop
        distributed.scheduler.TaskGroup.start
        """
        out = min(
            (tg.start for tg in self.traverse_groups() if tg.start != 0.0),
            default=0.0,
        )
        if out:
            # absorb small errors in worker delay calculation
            out = max(out, self.enqueued)
        return out

    @property
    def stop(self) -> float:
        """When this span tree finished computing, or current timestamp if it didn't
        finish yet.

        Notes
        -----
        The returned value is never less than ``enqueued``. A span tree without any
        task group is vacuously ``done``; in that case ``enqueued`` is returned.

        See Also
        --------
        enqueued
        start
        done
        distributed.scheduler.TaskGroup.stop
        """
        if self.done:
            # default= covers a branch with no task groups at all, for which a bare
            # max() would raise ValueError
            out = max((tg.stop for tg in self.traverse_groups()), default=0.0)
        else:
            out = time()
        # absorb small errors in worker delay calculation, as well as in time() not
        # being perfectly monotonic
        return max(out, self.enqueued)

    @property
    def metadata(self) -> SpanMetadata | None:
        """Metadata accumulated through :meth:`add_metadata`; None if never called"""
        return self._metadata

    @property
    def states(self) -> dict[TaskStateState, int]:
        """The number of tasks currently in each state in this span tree;
        e.g. ``{"memory": 10, "processing": 3, "released": 4, ...}``.

        See Also
        --------
        distributed.scheduler.TaskGroup.states
        """
        return sum_mappings(tg.states for tg in self.traverse_groups())

    @property
    def done(self) -> bool:
        """Return True if all tasks in this span tree are completed; False otherwise.

        Notes
        -----
        This property may transition from True to False, e.g. when a new sub-span is
        added or when a worker that contained the only replica of a task in memory
        crashes and the task need to be recomputed.

        See Also
        --------
        distributed.scheduler.TaskGroup.done
        """
        return all(tg.done for tg in self.traverse_groups())

    @property
    def all_durations(self) -> dict[str, float]:
        """Cumulative duration of all completed actions in this span tree, by action

        See Also
        --------
        duration
        distributed.scheduler.TaskGroup.all_durations
        """
        return sum_mappings(tg.all_durations for tg in self.traverse_groups())

    @property
    def duration(self) -> float:
        """The total amount of time spent on all tasks in this span tree

        See Also
        --------
        all_durations
        distributed.scheduler.TaskGroup.duration
        """
        return sum(tg.duration for tg in self.traverse_groups())

    @property
    def nbytes_total(self) -> int:
        """The total number of bytes that this span tree has produced

        See Also
        --------
        distributed.scheduler.TaskGroup.nbytes_total
        """
        return sum(tg.nbytes_total for tg in self.traverse_groups())

    @property
    def code(self) -> list[tuple[SourceCode, ...]]:
        """Code snippets, sent by the client on compute(), persist(), and submit().

        Only populated if ``distributed.diagnostics.computations.nframes`` is non-zero.
        """
        # Deduplicate, but preserve order
        return list(
            dict.fromkeys(sc for child in self.traverse_spans() for sc in child._code)
        )

    @property
    def cumulative_worker_metrics(self) -> dict[tuple[Hashable, ...], float]:
        """Replica of ``Worker.digests_total`` and
        ``Scheduler.cumulative_worker_metrics``, but only for the metrics that can be
        attributed to the current span tree. The span id has been removed from the key.

        At the moment of writing, all keys are
        ``("execute", <task prefix>, <activity>, <unit>)``
        or
        ``("p2p", <where>, <activity>, <unit>)``
        but more may be added in the future with a different format; please test e.g.
        for ``k[0] == "execute"``.

        An extra ``("execute", "N/A", "idle or other spans", "seconds")`` entry holds
        the part of :attr:`active_cpu_seconds` not covered by the "execute" metrics
        measured in seconds.
        """
        out = sum_mappings(
            child._cumulative_worker_metrics for child in self.traverse_spans()
        )
        known_seconds = sum(
            v for k, v in out.items() if k[0] == "execute" and k[-1] == "seconds"
        )
        # Besides rounding errors, you may get negative unknown seconds if a user
        # manually invokes `context_meter.digest_metric`.
        unknown_seconds = max(0.0, self.active_cpu_seconds - known_seconds)

        out["execute", "N/A", "idle or other spans", "seconds"] = unknown_seconds
        return out

    @staticmethod
    def merge(*items: Span) -> Span:
        """Merge multiple spans into a synthetic one.
        The input spans must not be related with each other.

        The output has name ``("(merged)",)`` and id ``"(merged)"``; the input spans
        become its children and are not modified.

        Raises
        ------
        ValueError
            If no spans are passed
        """
        if not items:
            raise ValueError("Nothing to merge")
        out = Span(
            name=("(merged)",),
            id_="(merged)",
            parent=None,
            total_nthreads_history=items[0]._total_nthreads_history,
        )
        # Start from the earliest input span, overriding what __init__ just set
        out._total_nthreads_offset = min(
            child._total_nthreads_offset for child in items
        )
        out.children.extend(items)
        out.enqueued = min(child.enqueued for child in items)
        return out

    def _nthreads_timeseries(self) -> Iterator[tuple[float, int]]:
        """Yield (timestamp, number of threads across the cluster), forward-fill"""
        # 0 means "no upper bound": the span is still running
        stop = self.stop if self.done else 0
        for t, n in islice(
            self._total_nthreads_history, self._total_nthreads_offset, None
        ):
            if stop and t >= stop:
                break
            # Entries older than the span are clipped to its enqueued time
            yield max(self.enqueued, t), n

    def _active_timeseries(self) -> Iterator[tuple[float, bool]]:
        """If this span is the output of :meth:`merge`, yield
        (timestamp, True if at least one input span is active), forward-fill.

        Otherwise yield ``(enqueued, True)`` followed by ``(stop, False)``.
        """
        if self.id != "(merged)":
            yield self.enqueued, True
            yield self.stop, False
            return

        events = []
        for child in self.children:
            events += [(child.enqueued, 1), (child.stop, -1)]
        # enqueued <= stop by construction.
        # Occasionally, enqueued == stop, e.g. when the clock is adjusted backwards.
        # Prevent negative n_active when this happens: sorting on the timestamp alone
        # is stable, so a +1 event stays ahead of the -1 event of the same child.
        events.sort(key=lambda el: el[0])

        n_active = 0
        for t, delta in events:
            if not n_active:
                assert delta == 1
                yield t, True
            n_active += delta
            if not n_active:
                yield t, False

    @property
    def nthreads_intervals(self) -> list[tuple[float, float, int]]:
        """
        Returns
        -------
        List of tuples:

        - begin timestamp
        - end timestamp
        - Scheduler.total_nthreads during this interval

        When the Span is the output of :meth:`merge`, the intervals may not be
        contiguous.

        See Also
        --------
        enqueued
        stop
        active_cpu_seconds
        distributed.scheduler.SchedulerState.total_nthreads
        """
        nthreads_t, nthreads_count = zip(*self._nthreads_timeseries())
        is_active_t, is_active_flag = zip(*self._active_timeseries())
        # Resample both series on the union of their timestamps
        t_interp = sorted({*nthreads_t, *is_active_t})
        nthreads_count_interp = ffill(t_interp, nthreads_t, nthreads_count, left=0)
        is_active_flag_interp = ffill(t_interp, is_active_t, is_active_flag, left=False)
        return [
            (t0, t1, n)
            for t0, t1, n, active in zip(
                t_interp, t_interp[1:], nthreads_count_interp, is_active_flag_interp
            )
            if active
        ]

    @property
    def active_cpu_seconds(self) -> float:
        """Return number of CPU seconds that were made available on the cluster while
        this Span was running; in other words
        ``(Span.stop - Span.enqueued) * Scheduler.total_nthreads``.

        This accounts for workers joining and leaving the cluster while this Span was
        active. If this Span is the output of :meth:`merge`, do not count gaps between
        input spans.

        See Also
        --------
        enqueued
        stop
        nthreads_intervals
        distributed.scheduler.SchedulerState.total_nthreads
        """
        return sum((t1 - t0) * nthreads for t0, t1, nthreads in self.nthreads_intervals)


class SpansSchedulerExtension:
    """Scheduler extension for spans support"""

    scheduler: Scheduler

    #: All Span objects by id
    spans: dict[str, Span]

    #: Only the spans that don't have any parents, sorted by creation time.
    #: This is a convenience helper structure to speed up searches.
    root_spans: list[Span]

    #: All spans, keyed by their full name and sorted by creation time.
    #: This is a convenience helper structure to speed up searches.
    spans_search_by_name: defaultdict[tuple[str, ...], list[Span]]

    #: All spans, keyed by the individual tags that make up their name and sorted by
    #: creation time. Each span appears at most once per tag, even if its name
    #: repeats the tag.
    #: This is a convenience helper structure to speed up searches.
    #:
    #: See Also
    #: --------
    #: find_by_tags
    #: merge_by_tags
    spans_search_by_tag: defaultdict[str, list[Span]]

    def __init__(self, scheduler: Scheduler):
        self.scheduler = scheduler
        self.spans = {}
        self.root_spans = []
        self.spans_search_by_name = defaultdict(list)
        self.spans_search_by_tag = defaultdict(list)

    def observe_tasks(
        self,
        tss: Iterable[scheduler_module.TaskState],
        code: tuple[SourceCode, ...],
        span_metadata: SpanMetadata,
    ) -> dict[Key, dict]:
        """Acknowledge the existence of runnable tasks on the scheduler. These may
        either be new tasks, tasks that were previously unrunnable, or tasks that were
        already fed into this method already.

        Attach newly observed tasks to either the desired span or to ("default", ).
        Update TaskGroup.span_id and rewrite TaskState.annotations["span"] so that it
        matches the span the task was actually attached to; the annotation is removed
        for tasks that end up in a default span.

        Parameters
        ----------
        tss
            Tasks to observe
        code
            Source code snippets; if not empty, recorded on the span of every task
        span_metadata
            If not empty, added to the span of every task through
            :meth:`Span.add_metadata`

        Returns
        -------
        Updated 'span' annotations: {key: {"name": (..., ...), "ids": (..., ...)}}
        """
        out = {}
        default_span = None

        for ts in tss:
            if ts.annotations is None:
                ts.annotations = {}
            # You may have different tasks belonging to the same TaskGroup but to
            # different spans. If that happens, arbitrarily force everything onto the
            # span of the earliest encountered TaskGroup.
            tg = ts.group
            if tg.span_id:
                span = self.spans[tg.span_id]
            else:
                ann = ts.annotations.get("span")
                if ann:
                    span = self._ensure_span(ann["name"], ann["ids"])
                else:
                    # Look up / create the default span at most once per call
                    if not default_span:
                        default_span = self._ensure_default_span()
                    span = default_span

                tg.span_id = span.id
                span.groups.add(tg)

            if code:
                span._code[code] = None
            if span_metadata:
                span.add_metadata(span_metadata)

            # The span may be completely different from the one referenced by the
            # annotation, due to the TaskGroup collision issue explained above.
            if ann := span.annotation:
                ts.annotations["span"] = out[ts.key] = ann
            else:
                ts.annotations.pop("span", None)

        return out

    def _ensure_default_span(self) -> Span:
        """Return the currently active default span, or create one if the previous one
        terminated. In other words, do not reuse the previous default span if all tasks
        that were not explicitly annotated with :func:`spans` on the client side are
        finished.
        """
        defaults = self.spans_search_by_name["default",]
        if defaults and not defaults[-1].done:
            return defaults[-1]
        return self._ensure_span(("default",), (str(uuid.uuid4()),))

    def _ensure_span(self, name: tuple[str, ...], ids: tuple[str, ...]) -> Span:
        """Create Span if it doesn't exist and return it.

        Missing ancestors are created too, root first. A span is looked up by its own
        id only, i.e. ``ids[-1]``.

        Parameters
        ----------
        name
            Full tuple of tags, from the root span down to the wanted one
        ids
            IDs matching ``name`` one to one
        """
        try:
            return self.spans[ids[-1]]
        except KeyError:
            pass

        assert len(name) == len(ids)
        assert len(name) > 0

        parent = None
        for i in range(1, len(name)):
            parent = self._ensure_span(name[:i], ids[:i])

        span = Span(
            name=name,
            id_=ids[-1],
            parent=parent,
            total_nthreads_history=self.scheduler.total_nthreads_history,
        )
        self.spans[span.id] = span
        self.spans_search_by_name[name].append(span)
        # A name may repeat the same tag, e.g. when span("x") is nested inside
        # span("x"); index the span only once per distinct tag.
        for tag in dict.fromkeys(name):
            self.spans_search_by_tag[tag].append(span)
        if parent:
            parent.children.append(span)
        else:
            self.root_spans.append(span)

        return span

    def find_by_tags(self, *tags: str) -> Iterator[Span]:
        """Yield all spans that contain any of the given tags.
        When a tag is shared both by a span and its (grand)children, only return the
        parent.

        Each span is yielded at most once, even if it matches several of the given
        tags or the same tag is passed more than once. Shallower spans are yielded
        first.
        """
        # {depth: insertion-ordered set of spans}
        by_level: defaultdict[int, dict[Span, None]] = defaultdict(dict)
        for tag in tags:
            # .get() instead of [] so that querying an unknown tag does not add an
            # empty entry to the index
            for sp in self.spans_search_by_tag.get(tag, ()):
                by_level[len(sp.name)][sp] = None

        seen: set[Span] = set()
        for depth in sorted(by_level):
            level = by_level[depth]
            seen.update(level)
            for sp in level:
                # A child's name contains all tags of its parent, so if the parent
                # was matched there is no need to return the child too
                if sp.parent not in seen:
                    yield sp

    def merge_all(self) -> Span:
        """Return a synthetic Span which is the sum of all spans"""
        return Span.merge(*self.root_spans)

    def merge_by_tags(self, *tags: str) -> Span:
        """Return a synthetic Span which is the sum of all spans containing the given
        tags
        """
        return Span.merge(*self.find_by_tags(*tags))

    def heartbeat(
        self, ws: scheduler_module.WorkerState, data: dict[tuple[Hashable, ...], float]
    ) -> None:
        """Triggered by :meth:`SpansWorkerExtension.heartbeat`.

        Populate :meth:`Span.cumulative_worker_metrics` with data from the worker.

        Parameters
        ----------
        ws
            Worker that sent the heartbeat (not used here)
        data
            ``{(context, span_id, *other): value}``. The value is added to the
            metric ``(context, *other)`` of the span with the given id.

        See Also
        --------
        SpansWorkerExtension.heartbeat
        Span.cumulative_worker_metrics
        """
        for (context, span_id, *other), v in data.items():
            assert isinstance(span_id, str)
            span = self.spans[span_id]
            span._cumulative_worker_metrics[(context, *other)] += v


class SpansWorkerExtension:
    """Worker extension for spans support"""

    worker: Worker
    #: Span-related metrics collected since the last call to :meth:`heartbeat`
    digests_total_since_heartbeat: dict[tuple[Hashable, ...], float]

    def __init__(self, worker: Worker):
        self.digests_total_since_heartbeat = {}
        self.worker = worker

    def collect_digests(
        self, digests_total_since_heartbeat: Mapping[Hashable, float]
    ) -> None:
        """Store the subset of the given metrics that carries a span id, i.e. those
        whose key is a tuple starting with one of ``CONTEXTS_WITH_SPAN_ID``.

        Metrics collected previously must have been flushed by :meth:`heartbeat`.
        """
        # Note: Worker._register_with_scheduler may trigger a spurious call to this
        # method; when that happens there are guaranteed to be no metrics to find
        assert not self.digests_total_since_heartbeat
        with_span_id = {}
        for key, value in digests_total_since_heartbeat.items():
            if not isinstance(key, tuple):
                continue
            if key[0] in CONTEXTS_WITH_SPAN_ID:
                with_span_id[key] = value
        self.digests_total_since_heartbeat = with_span_id

    def heartbeat(self) -> dict[tuple[Hashable, ...], float]:
        """Apportion the metrics that do have a span to the Spans on the scheduler.

        The collected metrics are handed over and the internal buffer is reset.

        Returns
        -------
        ``{(context, span_id, prefix, activity, unit): value}}``

        See Also
        --------
        SpansSchedulerExtension.heartbeat
        Span.cumulative_worker_metrics
        distributed.worker.Worker.get_metrics
        """
        collected, self.digests_total_since_heartbeat = (
            self.digests_total_since_heartbeat,
            {},
        )
        return collected