1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144
|
from __future__ import annotations
import uuid
from distributed.diagnostics.plugin import SchedulerPlugin
class GraphLayout(SchedulerPlugin):
    """Dynamic graph layout during computation

    This assigns (x, y) locations to all tasks quickly and dynamically as new
    tasks are added.  This scales to a few thousand nodes.

    It is commonly used with
    distributed/dashboard/components/scheduler.py::TaskGraph, which is rendered
    at /graph on the diagnostic dashboard.

    The plugin keeps two kinds of state:

    * the layout itself (``x``, ``y``, ``collision``), keyed by task key;
    * incremental change logs (``new``, ``new_edges``, ``state_updates``,
      ``visible_updates``, ``visible_edge_updates``) that refer to nodes and
      edges by the integer indices stored in ``index`` / ``index_edge``.
    """

    def __init__(self, scheduler):
        """Create the layout and lay out any tasks the scheduler already has.

        Parameters
        ----------
        scheduler
            Object exposing a ``tasks`` mapping of key -> task state, where
            each task state has ``key``, ``priority``, ``dependencies`` and
            ``dependents`` attributes.
        """
        self.name = f"graph-layout-{uuid.uuid4()}"
        # Layout coordinates, keyed by task key.
        self.x = {}
        self.y = {}
        # Maps a proposed (x, y) to the position most recently handed out for
        # that proposal, so repeated proposals can be nudged apart.
        self.collision = {}
        self.scheduler = scheduler
        # Integer indices for nodes (by key) and edges (by (dep, key) pair).
        self.index = {}
        self.index_edge = {}
        # Next free row for tasks without dependencies, and index counters.
        self.next_y = 0
        self.next_index = 0
        self.next_edge_index = 0
        # Change logs accumulated since the last ``reset_index``.
        self.new = []
        self.new_edges = []
        self.state_updates = []
        self.visible_updates = []
        self.visible_edge_updates = []

        if self.scheduler.tasks:
            # Replay the tasks that existed before this plugin was created.
            dependencies = {
                k: [ds.key for ds in ts.dependencies]
                for k, ts in scheduler.tasks.items()
            }
            priority = {k: ts.priority for k, ts in scheduler.tasks.items()}
            self.update_graph(
                self.scheduler,
                tasks=self.scheduler.tasks,
                dependencies=dependencies,
                priority=priority,
            )

    def update_graph(
        self, scheduler, dependencies=None, priority=None, tasks=None, **kwargs
    ):
        """Assign positions and indices to newly added tasks.

        Parameters
        ----------
        scheduler
            Object whose ``tasks`` mapping is consulted for each key.
        dependencies : mapping, optional
            Task key -> iterable of dependency keys.  ``None`` means no task
            has dependencies.
        priority : mapping, optional
            Task key -> sortable priority; missing keys default to ``0``.
            ``None`` means no priorities are known.
        tasks : iterable, optional
            Keys to lay out.  ``None`` or empty means there is nothing to do.
        **kwargs
            Accepted and ignored.

        Keys that are already laid out, or that are absent from
        ``scheduler.tasks``, are skipped.  A task without dependencies is put
        at ``x == 0`` on the next free row; any other task is put one column
        to the right of its right-most dependency, at the average ``y`` of its
        dependencies weighted by each dependency's number of dependents.
        """
        # The signature advertises ``None`` defaults, so they must be usable.
        dependencies = {} if dependencies is None else dependencies
        priority = {} if priority is None else priority
        if not tasks:
            return

        def by_priority(keys):
            # Highest priority value first, so popping from the end of the
            # stack visits the lowest priority value first.
            return sorted(keys, key=lambda k: priority.get(k, 0), reverse=True)

        stack = by_priority(tasks)
        while stack:
            key = stack.pop()
            if key in self.x or key not in scheduler.tasks:
                continue

            # Ignore dependencies that can never be placed (neither laid out
            # nor known to the scheduler).  Waiting for them would re-push
            # ``key`` forever, because they are skipped by the check above.
            deps = [
                dep
                for dep in dependencies.get(key, ())
                if dep in self.y or dep in scheduler.tasks
            ]

            if deps:
                if not all(dep in self.y for dep in deps):
                    # Lay out the missing dependencies first, then retry.
                    stack.append(key)
                    stack.extend(by_priority(deps))
                    continue

                weights = [len(scheduler.tasks[dep].dependents) for dep in deps]
                total_deps = sum(weights)
                if total_deps:
                    y = sum(
                        self.y[dep] * weight / total_deps
                        for dep, weight in zip(deps, weights)
                    )
                else:
                    # No dependents recorded on any dependency: fall back to
                    # a plain average instead of dividing by zero.
                    y = sum(self.y[dep] for dep in deps) / len(deps)
                x = max(self.x[dep] for dep in deps) + 1
            else:
                x = 0
                y = self.next_y
                self.next_y += 1

            if (x, y) in self.collision:
                # This spot was proposed before: place the task just below
                # the last task that was given it, and remember the new spot.
                old_x, old_y = x, y
                x, y = self.collision[(x, y)]
                y += 0.1
                self.collision[old_x, old_y] = (x, y)
            else:
                self.collision[(x, y)] = (x, y)

            self.x[key] = x
            self.y[key] = y
            self.index[key] = self.next_index
            self.next_index += 1
            self.new.append(key)
            for dep in deps:
                edge = (dep, key)
                self.index_edge[edge] = self.next_edge_index
                self.next_edge_index += 1
                self.new_edges.append(edge)

    def transition(self, key, start, finish, *args, **kwargs):
        """Record a task state change.

        Parameters
        ----------
        key
            Key of the task that changed state.
        start
            Previous state (unused).
        finish
            New state.  ``"forgotten"`` hides the node and its edges and
            removes the task from the layout; any other value is appended to
            ``state_updates``.
        *args, **kwargs
            Accepted and ignored.

        Keys that were never laid out are ignored.
        """
        node_index = self.index.get(key)
        if node_index is None:
            # ``update_graph`` skips some keys, so there may be nothing to do.
            return

        if finish != "forgotten":
            self.state_updates.append((node_index, finish))
            return

        self.visible_updates.append((node_index, "False"))
        task = self.scheduler.tasks[key]
        edges = [(key, dep.key) for dep in task.dependents]
        edges.extend((dep.key, key) for dep in task.dependencies)
        for edge in edges:
            # An edge is reachable from both of its endpoints, so it may
            # already have been hidden when the other endpoint was forgotten.
            edge_index = self.index_edge.pop(edge, None)
            if edge_index is not None:
                self.visible_edge_updates.append((edge_index, "False"))

        # The stored position is not necessarily a key of ``collision`` (a
        # nudged task is stored at its nudged spot), hence the default.
        self.collision.pop((self.x[key], self.y[key]), None)

        for collection in (self.x, self.y, self.index):
            del collection[key]

    def reset_index(self):
        """Reset the index and refill new and new_edges

        From time to time TaskGraph wants to remove invisible nodes and reset
        all of its indices.  This helps.

        All change logs are cleared, and every task still in the layout is
        re-indexed from zero together with the edges from its current
        dependencies.
        """
        self.new = []
        self.new_edges = []
        self.visible_updates = []
        self.state_updates = []
        self.visible_edge_updates = []
        self.index = {}
        self.next_index = 0
        self.index_edge = {}
        self.next_edge_index = 0

        for key in self.x:
            self.index[key] = self.next_index
            self.next_index += 1
            self.new.append(key)
            for dep in self.scheduler.tasks[key].dependencies:
                edge = (dep.key, key)
                self.index_edge[edge] = self.next_edge_index
                self.next_edge_index += 1
                self.new_edges.append(edge)
|