File: test_reschedule.py

package info (click to toggle)
dask.distributed 2022.12.1+ds.1-3
  • links: PTS, VCS
  • area: main
  • in suites: bookworm
  • size: 10,164 kB
  • sloc: python: 81,938; javascript: 1,549; makefile: 228; sh: 100
file content (140 lines) | stat: -rw-r--r-- 4,380 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
"""Tests for tasks raising the Reschedule exception and Scheduler._reschedule().

Note that this functionality is also used by work stealing;
see test_steal.py for additional tests.
"""
from __future__ import annotations

import asyncio
from time import sleep

import pytest

from distributed import Event, Reschedule, get_worker, secede, wait
from distributed.utils_test import captured_logger, gen_cluster, slowinc
from distributed.worker_state_machine import (
    ComputeTaskEvent,
    FreeKeysEvent,
    GatherDep,
    RescheduleEvent,
    RescheduleMsg,
)


@gen_cluster()
async def test_scheduler_reschedule_warns(s, a, b):
    """Rescheduling a key the scheduler does not know is aborted and logged,
    rather than raising."""
    with captured_logger("distributed.scheduler") as log:
        s._reschedule(key="__this-key-does-not-exist__", stimulus_id="test")

    # The logger capture is finished; inspect its text once.
    text = log.getvalue()
    assert "not found on the scheduler" in text
    assert "Aborting reschedule" in text


@pytest.mark.parametrize("state", ["executing", "long-running"])
@gen_cluster(
    client=True,
    nthreads=[("", 1)] * 2,
    config={"distributed.scheduler.work-stealing": False},
)
async def test_raise_reschedule(c, s, a, b, state):
    """A task raises Reschedule() on worker ``a`` and ends up computed on ``b``.

    Work stealing is disabled, so the only way for a task to leave ``a`` is
    the Reschedule() path under test. Worker ``a`` is additionally loaded with
    "clog" tasks pinned to it via ``workers=[a.address]``.
    """
    a_address = a.address

    def f(x):
        if state == "long-running":
            secede()
        sleep(0.1)
        if get_worker().address == a_address:
            raise Reschedule()

    futures = c.map(f, range(4), key=["x1", "x2", "x3", "x4"])
    # Hold a reference to the clog futures so they are not released while
    # ``futures`` are being computed.
    futures2 = c.map(slowinc, range(10), delay=0.1, key="clog", workers=[a.address])
    await wait(futures)
    assert any(isinstance(ev, RescheduleEvent) for ev in a.state.stimulus_log)
    assert all(fut.key in b.data for fut in futures)
    # Every task that raised Reschedule() on ``a`` must have been forgotten
    # by ``a``. Check the actual keys; no task in this test is named "x".
    assert not any(fut.key in a.state.tasks for fut in futures)


@pytest.mark.parametrize("state", ["executing", "long-running"])
@gen_cluster(client=True, nthreads=[("", 1)])
async def test_cancelled_reschedule(c, s, a, state):
    """The client releases a future, and afterwards its task raises Reschedule().

    Cluster-level counterpart of test_cancelled_reschedule_worker_state.
    """
    started = Event()
    proceed = Event()

    def raise_after_release(started, proceed):
        if state == "long-running":
            secede()
        started.set()
        proceed.wait()
        raise Reschedule()

    fut = c.submit(raise_after_release, started, proceed, key="x")
    await started.wait()
    fut.release()
    # Phase 1: wait until the scheduler has forgotten the released key.
    while "x" in s.tasks:
        await asyncio.sleep(0.01)

    # Phase 2: let the task raise, then wait until the worker forgets it too.
    await proceed.set()
    while "x" in a.state.tasks:
        await asyncio.sleep(0.01)


def test_cancelled_reschedule_worker_state(ws_with_running_task):
    """Worker-state-machine counterpart of test_cancelled_reschedule.

    Freeing the running task cancels it without emitting instructions. A
    later RescheduleEvent silently forgets it and gives back resource "R".
    """
    worker_state = ws_with_running_task

    after_free = worker_state.handle_stimulus(
        FreeKeysEvent(keys=["x"], stimulus_id="s1")
    )
    assert not after_free
    assert worker_state.tasks["x"].state == "cancelled"
    # Resource "R" has not been returned yet at this point
    assert worker_state.available_resources == {"R": 0}

    after_reschedule = worker_state.handle_stimulus(
        RescheduleEvent(key="x", stimulus_id="s2")
    )
    assert not after_reschedule  # a cancelled task produces no RescheduleMsg
    assert not worker_state.tasks  # and is dropped from the worker entirely
    assert worker_state.available_resources == {"R": 1}


def test_reschedule_releases(ws_with_running_task):
    """Rescheduling the running task emits a RescheduleMsg, returns resource
    "R" and removes the task from the worker state."""
    worker_state = ws_with_running_task

    emitted = worker_state.handle_stimulus(
        RescheduleEvent(key="x", stimulus_id="s1")
    )
    expected = [RescheduleMsg(stimulus_id="s1", key="x")]
    assert emitted == expected
    assert worker_state.available_resources == {"R": 1}
    assert "x" not in worker_state.tasks


def test_reschedule_cancelled(ws_with_running_task):
    """Test state loop:

    executing -> cancelled -> rescheduled
    executing -> long-running -> cancelled -> rescheduled

    No instructions are expected, and the task ends up forgotten.
    """
    worker_state = ws_with_running_task
    stimuli = (
        FreeKeysEvent(keys=["x"], stimulus_id="s1"),
        RescheduleEvent(key="x", stimulus_id="s2"),
    )

    emitted = worker_state.handle_stimulus(*stimuli)
    assert not emitted
    assert "x" not in worker_state.tasks


def test_reschedule_resumed(ws_with_running_task):
    """Test state loop:

    executing -> cancelled -> resumed(fetch) -> rescheduled
    executing -> long-running -> cancelled -> resumed(fetch) -> rescheduled

    Once the cancelled task is needed as a dependency of "y", rescheduling it
    results in a fetch from the peer that holds it.
    """
    worker_state = ws_with_running_task
    peer = "127.0.0.1:2"
    stimuli = (
        FreeKeysEvent(keys=["x"], stimulus_id="s1"),
        ComputeTaskEvent.dummy("y", who_has={"x": [peer]}, stimulus_id="s2"),
        RescheduleEvent(key="x", stimulus_id="s3"),
    )

    emitted = worker_state.handle_stimulus(*stimuli)
    expected = [
        GatherDep(worker=peer, to_gather={"x"}, total_nbytes=1, stimulus_id="s3")
    ]
    assert emitted == expected
    assert worker_state.tasks["x"].state == "flight"