File: test_stories.py

from __future__ import annotations

import pytest

import dask

from distributed import Worker
from distributed.comm import CommClosedError
from distributed.utils_test import (
    NO_AMM,
    assert_story,
    assert_valid_story,
    gen_cluster,
    inc,
)

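# Both scheduler and worker stories are lists of plain tuples whose last two
# fields are a stimulus_id (str) and a timestamp (float); on the scheduler
# side an event follows the transition-log layout
# (key, start, finish, recommendations, stimulus_id, timestamp). That is why
# the tests below read the stimulus_id as ``ev[-2]`` and the timestamp as
# ``ev[-1]``. A hypothetical scheduler event, for illustration only:
#
#     ("inc-123", "released", "waiting", {"inc-123": "processing"},
#      "update-graph-1234", 1670000000.0)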

@gen_cluster(client=True, nthreads=[("", 1)])
async def test_scheduler_story_stimulus_success(c, s, a):
    f = c.submit(inc, 1)
    key = f.key

    await f

    stories = s.story(key)

    stimulus_ids = {ev[-2] for ev in stories}
    # Two events
    # - Compute
    # - Success
    assert len(stimulus_ids) == 2
    assert_story(
        stories,
        [
            (key, "released", "waiting", {key: "processing"}),
            (key, "waiting", "processing", {}),
            (key, "processing", "memory", {}),
        ],
    )

    await c.close()

    stories_after_close = s.story(key)
    assert len(stories_after_close) > len(stories)

    stimulus_ids = {ev[-2] for ev in stories_after_close}
    # One more event
    # - Forget / Release / Free since client closed
    assert len(stimulus_ids) == 3


@gen_cluster(client=True, nthreads=[("", 1)])
async def test_scheduler_story_stimulus_retry(c, s, a):
    def task():
        assert dask.config.get("foo")

    with dask.config.set(foo=False):
        f = c.submit(task)
        with pytest.raises(AssertionError):
            await f

    with dask.config.set(foo=True):
        await f.retry()
        await f

    story = s.story(f.key)
    stimulus_ids = {ev[-2] for ev in story}
    # Four events
    # - Compute
    # - Erred
    # - Compute / Retry
    # - Success
    assert len(stimulus_ids) == 4

    assert_story(
        story,
        [
            # Erred transitions via released
            (f.key, "processing", "erred", {}),
            (f.key, "erred", "released", {}),
            (f.key, "released", "waiting", {f.key: "processing"}),
        ],
    )


@gen_cluster(client=True, nthreads=[("", 1)])
async def test_client_story(c, s, a):
    f = c.submit(inc, 1)
    assert await f == 2
    story = await c.story(f.key)

    # Every event should be prefixed with its origin.
    # This changes the format compared to the default scheduler / worker stories.
    prefixes = set()
    stripped_story = []
    for msg in story:
        prefixes.add(msg[0])
        stripped_story.append(msg[1:])
    assert prefixes == {"scheduler", a.address}

    assert_valid_story(stripped_story, ordered_timestamps=False)

    # If it's a well-formed story, we can sort by the last element, which is
    # a timestamp, and compare the two lists.
    assert sorted(stripped_story, key=lambda msg: msg[-1]) == sorted(
        s.story(f.key) + a.state.story(f.key), key=lambda msg: msg[-1]
    )

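# For reference, a hypothetical event as returned by ``Client.story``, with
# the origin prefix that test_client_story strips off before validating:
#
#     ("scheduler", "inc-123", "released", "waiting",
#      {"inc-123": "processing"}, "update-graph-1234", 1670000000.0)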

class WorkerBrokenStory(Worker):
    async def get_story(self, *args, **kw):
        raise CommClosedError


@pytest.mark.parametrize("on_error", ["ignore", "raise"])
@gen_cluster(client=True, Worker=WorkerBrokenStory)
async def test_client_story_failed_worker(c, s, a, b, on_error):
    f = c.submit(inc, 1)
    coro = c.story(f.key, on_error=on_error)
    await f

    if on_error == "raise":
        with pytest.raises(CommClosedError):
            await coro
    elif on_error == "ignore":
        story = await coro
        assert story
        assert len(story) > 1
    else:
        raise ValueError(on_error)

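# ``Client.story`` takes an ``on_error`` argument: "raise" propagates a
# failure to fetch a worker's story (simulated above by WorkerBrokenStory
# raising CommClosedError), while "ignore" drops that worker's events and
# returns whatever could be collected.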

@gen_cluster(client=True, config=NO_AMM)
async def test_worker_story_with_deps(c, s, a, b):
    """
    Assert that the structure of the story does not change unintentionally
    and that the expected subfields are actually filled
    """
    dep = c.submit(inc, 1, workers=[a.address], key="dep")
    res = c.submit(inc, dep, workers=[b.address], key="res")
    await res

    story = a.state.story("res")
    assert story == []

    # Stories include randomized stimulus_id suffixes and timestamps,
    # so strip the random suffix before comparing.
    story = b.state.story("res")
    stimulus_ids = {ev[-2].rsplit("-", 1)[0] for ev in story}
    assert stimulus_ids == {"compute-task", "gather-dep-success", "task-finished"}
    # This is a simple transition log
    expected = [
        ("res", "compute-task", "released"),
        ("res", "released", "waiting", "waiting", {"dep": "fetch"}),
        ("res", "waiting", "ready", "ready", {"res": "executing"}),
        ("res", "ready", "executing", "executing", {}),
        ("res", "put-in-memory"),
        ("res", "executing", "memory", "memory", {}),
    ]
    assert_story(story, expected, strict=True)

    story = b.state.story("dep")
    stimulus_ids = {ev[-2].rsplit("-", 1)[0] for ev in story}
    assert stimulus_ids == {"compute-task", "gather-dep-success"}
    expected = [
        ("dep", "ensure-task-exists", "released"),
        ("dep", "released", "fetch", "fetch", {}),
        ("gather-dependencies", a.address, {"dep"}),
        ("dep", "fetch", "flight", "flight", {}),
        ("request-dep", a.address, {"dep"}),
        ("receive-dep", a.address, {"dep"}),
        ("dep", "put-in-memory"),
        ("dep", "flight", "memory", "memory", {"res": "ready"}),
    ]
    assert_story(story, expected, strict=True)
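

# A minimal sketch (not exercised by the tests) of the merge-and-sort idiom
# from test_client_story: combine the scheduler's and a worker's raw stories
# for a key and order them by the trailing timestamp.


def merged_story(scheduler, worker, key):
    """Merge scheduler and worker story events for ``key`` by timestamp.

    Assumes the (..., stimulus_id, timestamp) event layout relied on
    throughout this file.
    """
    return sorted(
        scheduler.story(key) + worker.state.story(key),
        key=lambda ev: ev[-1],
    )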