File: _async_taskgroup.py

package info (click to toggle)
dask.distributed 2024.12.1+ds-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 12,588 kB
  • sloc: python: 96,954; javascript: 1,549; sh: 390; makefile: 220
file content (159 lines) | stat: -rw-r--r-- 4,587 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
from __future__ import annotations

import asyncio
import functools
import threading
from collections.abc import Callable, Coroutine
from typing import TYPE_CHECKING, Any, TypeVar

# Names used only inside type annotations. ``TYPE_CHECKING`` is False at
# runtime, so this branch is seen only by static type checkers; the postponed
# annotations enabled by ``from __future__ import annotations`` mean these
# names are never looked up when the module actually executes.
if TYPE_CHECKING:
    from typing_extensions import ParamSpec

    P = ParamSpec("P")
    R = TypeVar("R")
    T = TypeVar("T")
    # Shorthand for a coroutine object whose awaited result has type T
    # (Coroutine[YieldType, SendType, ReturnType]).
    Coro = Coroutine[Any, Any, T]


class _LoopBoundMixin:
    """Backport of the private ``asyncio.mixins._LoopBoundMixin`` from Python 3.11.

    The first call to :meth:`_get_loop` made while an event loop is running
    binds the instance to that loop; a later call from a different running
    loop raises ``RuntimeError``.
    """

    # Class-level lock guarding the one-time binding of ``_loop``; shared by
    # all instances.
    _global_lock = threading.Lock()

    # Event loop this instance is bound to; None until the first _get_loop().
    _loop = None

    def _get_loop(self):
        """Return the running event loop, binding this object to it on first use.

        Raises
        ------
        RuntimeError
            If no event loop is running (from ``asyncio.get_running_loop``),
            or if this object is already bound to a different event loop.
        """
        running = asyncio.get_running_loop()
        if self._loop is None:
            # Re-check under the lock so only one caller performs the binding.
            with self._global_lock:
                if self._loop is None:
                    self._loop = running
        if running is self._loop:
            return running
        raise RuntimeError(f"{self!r} is bound to a different event loop")


class AsyncTaskGroupClosedError(RuntimeError):
    """Raised when new work is scheduled on an ``AsyncTaskGroup`` that is closed."""


def _delayed(corofunc: Callable[P, Coro[T]], delay: float) -> Callable[P, Coro[T]]:
    """Wrap a coroutine function so that each call first sleeps for ``delay`` seconds.

    Parameters
    ----------
    corofunc
        Coroutine function to wrap.
    delay
        Number of seconds passed to ``asyncio.sleep`` before ``corofunc`` is
        called and awaited.

    Returns
    -------
    A coroutine function taking the same arguments as ``corofunc`` and
    returning its result after the delay. ``__name__``, ``__qualname__``,
    ``__doc__`` and ``__wrapped__`` are copied from ``corofunc`` via
    ``functools.wraps`` so the wrapper stays identifiable when debugging.
    """

    @functools.wraps(corofunc)
    async def wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
        await asyncio.sleep(delay)
        # corofunc is only called after the sleep, so its coroutine object is
        # created lazily once the delay has elapsed.
        return await corofunc(*args, **kwargs)

    return wrapper


class AsyncTaskGroup(_LoopBoundMixin):
    """Track every asynchronous task that is currently running within one group.

    Tasks are created with :meth:`call_soon` or :meth:`call_later` and are
    dropped from the group automatically once they finish. After
    :meth:`close` (or :meth:`stop`) the group refuses new work.
    """

    #: If True, the group is closed and does not allow adding new tasks.
    closed: bool

    def __init__(self) -> None:
        self.closed = False
        self._ongoing_tasks: set[asyncio.Task[None]] = set()

    def call_soon(
        self, afunc: Callable[P, Coro[None]], /, *args: P.args, **kwargs: P.kwargs
    ) -> None:
        """Run ``afunc(*args, **kwargs)`` as a new `asyncio.Task` in this group.

        Parameters
        ----------
        afunc
            Coroutine function to schedule.
        *args
            Positional arguments forwarded to `afunc`.
        **kwargs
            Keyword arguments forwarded to `afunc`.

        Returns
        -------
            None

        Raises
        ------
        AsyncTaskGroupClosedError
            If the task group is closed.
        """
        if self.closed:
            # Refuse before afunc is called, so no coroutine object is created.
            raise AsyncTaskGroupClosedError(
                "Cannot schedule a new coroutine function as the group is already closed."
            )
        loop = self._get_loop()
        task = loop.create_task(afunc(*args, **kwargs))
        # Drop the task from the group as soon as it is done.
        task.add_done_callback(self._ongoing_tasks.remove)
        self._ongoing_tasks.add(task)

    def call_later(
        self,
        delay: float,
        afunc: Callable[P, Coro[None]],
        /,
        *args: P.args,
        **kwargs: P.kwargs,
    ) -> None:
        """Run ``afunc(*args, **kwargs)`` as a group task after `delay` seconds.

        The task is created immediately; it sleeps for `delay` seconds before
        calling and awaiting `afunc`.

        Parameters
        ----------
        delay
            Delay in seconds.
        afunc
            Coroutine function to schedule.
        *args
            Positional arguments forwarded to `afunc`.
        **kwargs
            Keyword arguments forwarded to `afunc`.

        Returns
        -------
            None

        Raises
        ------
        AsyncTaskGroupClosedError
            If the task group is closed.
        """
        delayed_afunc = _delayed(afunc, delay)
        self.call_soon(delayed_afunc, *args, **kwargs)

    def close(self) -> None:
        """Mark the group as closed so that no new tasks can be scheduled.

        Tasks that are already running are left untouched.
        """
        self.closed = True

    async def stop(self) -> None:
        """Close the group, then cancel and await all of its running tasks.

        The calling task is excluded, so ``stop`` may be awaited from inside
        one of the group's own tasks. Cancellation is repeated until no other
        task remains in the group. If the calling task is itself cancelled
        while waiting, the remaining tasks are cancelled once more, and the
        last ``CancelledError`` received is re-raised after the group has
        drained.
        """
        self.close()

        me = asyncio.current_task(self._get_loop())
        pending_cancel: asyncio.CancelledError | None = None
        while True:
            victims = self._ongoing_tasks - {me}
            if not victims:
                break
            for victim in victims:
                victim.cancel()
            try:
                await asyncio.wait(victims)
            except asyncio.CancelledError as exc:
                # Remember the cancellation but keep draining the group first.
                pending_cancel = exc

        if pending_cancel is not None:
            raise pending_cancel

    def __len__(self):
        """Return the number of tasks of this group that have not finished yet."""
        return len(self._ongoing_tasks)